Skip to content

Commit b3c059d

Browse files
committed
nvm, let's do it much simpler
1 parent 68dd4c6 commit b3c059d

File tree

2 files changed

+15
-32
lines changed

2 files changed

+15
-32
lines changed

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,15 @@ describe("parseSafetensorsMetadata", () => {
110110
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
111111
});
112112

113-
it("fetch info for sharded, but get param count directly from metadata", async () => {
113+
it.only("fetch info for sharded, but get param count directly from metadata", async () => {
114114
const parse = await parseSafetensorsMetadata({
115115
repo: "hf-internal-testing/sharded-model-metadata-num-parameters",
116116
computeParametersCount: true,
117117
revision: "999395eb3db277f3d7a0393402b02486ca91cef8",
118118
});
119119

120120
assert(parse.sharded);
121-
assert.deepStrictEqual(parse.parameterCount, { UNK: 109_482_240 });
121+
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
122122
// total params = 109M
123123
});
124124

@@ -131,7 +131,7 @@ describe("parseSafetensorsMetadata", () => {
131131
});
132132

133133
assert(!parse.sharded);
134-
assert.deepStrictEqual(parse.parameterCount, { UNK: 666 });
134+
assert.deepStrictEqual(parse.parameterTotal, 666);
135135
});
136136

137137
it("should detect sharded safetensors filename", async () => {

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ export type Dtype =
5555
| "I8"
5656
| "U16"
5757
| "U8"
58-
| "BOOL"
59-
| "UNK"; /// when the total_parameters is stored directly in the header, we use this dummy dtype
58+
| "BOOL";
6059

6160
export interface TensorInfo {
6261
dtype: Dtype;
@@ -83,12 +82,14 @@ export type SafetensorsParseFromRepo =
8382
sharded: false;
8483
header: SafetensorsFileHeader;
8584
parameterCount?: Partial<Record<Dtype, number>>;
85+
parameterTotal?: number;
8686
}
8787
| {
8888
sharded: true;
8989
index: SafetensorsIndexJson;
9090
headers: SafetensorsShardedHeaders;
9191
parameterCount?: Partial<Record<Dtype, number>>;
92+
parameterTotal?: number;
9293
};
9394

9495
async function parseSingleFile(
@@ -205,7 +206,6 @@ export async function parseSafetensorsMetadata(
205206
* @default false
206207
*/
207208
computeParametersCount: true;
208-
fetchAllHeaders?: boolean;
209209
hubUrl?: string;
210210
revision?: string;
211211
/**
@@ -225,12 +225,6 @@ export async function parseSafetensorsMetadata(
225225
* @default false
226226
*/
227227
computeParametersCount?: boolean;
228-
/**
229-
* Always fetch all headers (no shortcut)
230-
*
231-
* @default false
232-
*/
233-
fetchAllHeaders?: boolean;
234228
hubUrl?: string;
235229
revision?: string;
236230
/**
@@ -244,7 +238,6 @@ export async function parseSafetensorsMetadata(
244238
repo: RepoDesignation;
245239
path?: string;
246240
computeParametersCount?: boolean;
247-
fetchAllHeaders?: boolean;
248241
hubUrl?: string;
249242
revision?: string;
250243
/**
@@ -270,6 +263,11 @@ export async function parseSafetensorsMetadata(
270263
...(params.computeParametersCount
271264
? {
272265
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
266+
parameterTotal:
267+
/// shortcut: get param count directly from metadata
268+
header.__metadata__.total_parameters
269+
? parseInt(header.__metadata__.total_parameters.toString())
270+
: undefined,
273271
}
274272
: undefined),
275273
};
@@ -279,21 +277,7 @@ export async function parseSafetensorsMetadata(
279277
) {
280278
const path = params.path ?? SAFETENSORS_INDEX_FILE;
281279
const index = await parseShardedIndex(path, params);
282-
283-
const shardedMap =
284-
params.fetchAllHeaders || (params.computeParametersCount && !index.metadata?.total_parameters)
285-
? await fetchAllHeaders(path, index, params)
286-
: {};
287-
288-
if (params.computeParametersCount && index.metadata?.total_parameters) {
289-
/// shortcut: get param count directly from metadata
290-
return {
291-
sharded: true,
292-
index,
293-
headers: shardedMap,
294-
parameterCount: { UNK: parseInt(index.metadata.total_parameters.toString()) },
295-
};
296-
}
280+
const shardedMap = await fetchAllHeaders(path, index, params);
297281

298282
return {
299283
sharded: true,
@@ -302,6 +286,9 @@ export async function parseSafetensorsMetadata(
302286
...(params.computeParametersCount
303287
? {
304288
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap),
289+
parameterTotal:
290+
/// shortcut: get param count directly from metadata
291+
index.metadata?.total_parameters ? parseInt(index.metadata.total_parameters.toString()) : undefined,
305292
}
306293
: undefined),
307294
};
@@ -311,10 +298,6 @@ export async function parseSafetensorsMetadata(
311298
}
312299

313300
function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
314-
if (header.__metadata__.total_parameters) {
315-
/// shortcut: get param count directly from metadata
316-
return { UNK: parseInt(header.__metadata__.total_parameters.toString()) };
317-
}
318301
const counter: Partial<Record<Dtype, number>> = {};
319302
const tensors = omit(header, "__metadata__");
320303

0 commit comments

Comments
 (0)