Skip to content

Commit 6e02e35

Browse files
committed
implem
1 parent cbdcdea commit 6e02e35

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -65,13 +65,13 @@ export interface TensorInfo {
6565
}
6666

6767
export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
68-
__metadata__: Record<string, string>;
68+
__metadata__: { total_parameters?: string | number } & Record<string, string>;
6969
};
7070

7171
export interface SafetensorsIndexJson {
7272
dtype?: string;
7373
/// ^there's sometimes a dtype but it looks inconsistent.
74-
metadata?: Record<string, string>;
74+
metadata?: { total_parameters?: string | number } & Record<string, string>;
7575
/// ^ why the naming inconsistency?
7676
weight_map: Record<TensorName, FileName>;
7777
}
@@ -256,13 +256,14 @@ export async function parseSafetensorsMetadata(
256256
(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
257257
) {
258258
const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
259+
259260
return {
260261
sharded: true,
261262
index,
262263
headers,
263264
...(params.computeParametersCount
264265
? {
265-
parameterCount: computeNumOfParamsByDtypeSharded(headers),
266+
parameterCount: computeNumOfParamsByDtypeSharded(index, headers),
266267
}
267268
: undefined),
268269
};
@@ -272,6 +273,10 @@ export async function parseSafetensorsMetadata(
272273
}
273274

274275
function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
276+
if (header.__metadata__.total_parameters) {
277+
/// shortcut: get param count directly from metadata
278+
return { UNK: parseInt(header.__metadata__.total_parameters.toString()) };
279+
}
275280
const counter: Partial<Record<Dtype, number>> = {};
276281
const tensors = omit(header, "__metadata__");
277282

@@ -284,7 +289,14 @@ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Par
284289
return counter;
285290
}
286291

287-
function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
292+
function computeNumOfParamsByDtypeSharded(
293+
index: SafetensorsIndexJson,
294+
shardedMap: SafetensorsShardedHeaders
295+
): Partial<Record<Dtype, number>> {
296+
if (index.metadata?.total_parameters) {
297+
/// shortcut: get param count directly from metadata
298+
return { UNK: parseInt(index.metadata.total_parameters.toString()) };
299+
}
288300
const counter: Partial<Record<Dtype, number>> = {};
289301
for (const header of Object.values(shardedMap)) {
290302
for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {

0 commit comments

Comments
 (0)