From 1851fa37252c70844d1ee43f347488809ff24ecf Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 17:33:42 +0200 Subject: [PATCH 01/11] proper repo id for `bert-base-uncased` --- packages/hub/src/lib/file-download-info.spec.ts | 6 +++--- packages/hub/src/lib/file-exists.spec.ts | 4 ++-- packages/hub/src/lib/list-files.spec.ts | 4 ++-- packages/hub/src/lib/parse-safetensors-metadata.spec.ts | 2 +- packages/hub/src/lib/paths-info.spec.ts | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/packages/hub/src/lib/file-download-info.spec.ts b/packages/hub/src/lib/file-download-info.spec.ts index d2be156626..0a18cc7573 100644 --- a/packages/hub/src/lib/file-download-info.spec.ts +++ b/packages/hub/src/lib/file-download-info.spec.ts @@ -5,7 +5,7 @@ describe("fileDownloadInfo", () => { it("should fetch LFS file info", async () => { const info = await fileDownloadInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5", @@ -19,7 +19,7 @@ describe("fileDownloadInfo", () => { it("should fetch raw LFS pointer info", async () => { const info = await fileDownloadInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5", @@ -34,7 +34,7 @@ describe("fileDownloadInfo", () => { it("should fetch non-LFS file info", async () => { const info = await fileDownloadInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tokenizer_config.json", diff --git a/packages/hub/src/lib/file-exists.spec.ts b/packages/hub/src/lib/file-exists.spec.ts index e20acdf3bc..54c8ccd90e 100644 --- a/packages/hub/src/lib/file-exists.spec.ts +++ b/packages/hub/src/lib/file-exists.spec.ts @@ -5,7 +5,7 @@ describe("fileExists", () => { it("should return true for file that exists", async () => { const info = await fileExists({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5", @@ -18,7 +18,7 @@ describe("fileExists", () => { it("should return false for file that does not exist", async () => { const info = await fileExists({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5dadazdzazd", diff --git a/packages/hub/src/lib/list-files.spec.ts b/packages/hub/src/lib/list-files.spec.ts index 7014193075..00d3777de8 100644 --- a/packages/hub/src/lib/list-files.spec.ts +++ b/packages/hub/src/lib/list-files.spec.ts @@ -6,7 +6,7 @@ describe("listFiles", () => { it("should fetch the list of files from the repo", async () => { const cursor = listFiles({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", @@ -67,7 +67,7 @@ describe("listFiles", () => { it("should fetch the list of files from the repo, including last commit", async () => { const cursor = listFiles({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index d96f5ed650..ca7fa234a1 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -5,7 +5,7 @@ import { sum } from "../utils/sum"; describe("parseSafetensorsMetadata", () => { it("fetch info for single-file (with the default conventional filename)", async () => { const parse = await parseSafetensorsMetadata({ - repo: "bert-base-uncased", + repo: "google-bert/bert-base-uncased", computeParametersCount: true, revision: "86b5e0934494bd15c9632b12f734a8a67f723594", }); diff --git a/packages/hub/src/lib/paths-info.spec.ts b/packages/hub/src/lib/paths-info.spec.ts index 837f4a1924..84219b66e7 100644 --- a/packages/hub/src/lib/paths-info.spec.ts +++ b/packages/hub/src/lib/paths-info.spec.ts @@ -6,7 +6,7 @@ describe("pathsInfo", () => { it("should fetch LFS path info", async () => { const result: PathInfo[] = await pathsInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, paths: ["tf_model.h5"], @@ -35,7 +35,7 @@ describe("pathsInfo", () => { securityFileStatus: SecurityFileStatus; })[] = await pathsInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, paths: ["tf_model.h5"], @@ -59,7 +59,7 @@ describe("pathsInfo", () => { it("non-LFS pointer should have lfs undefined", async () => { const result: PathInfo[] = await pathsInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, paths: ["config.json"], From 210fc47b7368caed33cb01b7549ef88ec6009346 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 17:57:18 +0200 Subject: [PATCH 02/11] improve coding style --- packages/hub/src/lib/parse-safetensors-metadata.ts | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index ca43a00883..bf5b70d49f 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -189,12 +189,12 @@ export async function parseSafetensorsMetadata( params: { /** Only models are supported */ repo: RepoDesignation; + path?: string; /** * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType * * @default false */ - path?: string; computeParametersCount?: boolean; hubUrl?: string; revision?: string; @@ -223,7 +223,10 @@ export async function parseSafetensorsMetadata( throw new TypeError("Only model repos should contain safetensors files."); } - if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) { + if ( + (params.path && RE_SAFETENSORS_FILE.test(params.path)) || + (await fileExists({ ...params, path: SAFETENSORS_FILE })) + ) { const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params); return { sharded: false, @@ -233,7 +236,7 @@ export async function parseSafetensorsMetadata( }), }; } else if ( - RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") || + (params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) || (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE })) ) { const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params); From a648319db343ae8aa5d7a25ea6d7e7bc3b444849 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 18:09:27 +0200 Subject: [PATCH 03/11] WE DO NOT USE THIS && BULLSHIT --- .../hub/src/lib/parse-safetensors-metadata.ts | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index bf5b70d49f..f9bdc8c3e4 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -231,9 +231,11 @@ export async function parseSafetensorsMetadata( return { sharded: false, header, - ...(params.computeParametersCount && { - parameterCount: computeNumOfParamsByDtypeSingleFile(header), - }), + ...(params.computeParametersCount + ? { + parameterCount: computeNumOfParamsByDtypeSingleFile(header), + } + : undefined), }; } else if ( (params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) || @@ -244,9 +246,11 @@ export async function parseSafetensorsMetadata( sharded: true, index, headers, - ...(params.computeParametersCount && { - parameterCount: computeNumOfParamsByDtypeSharded(headers), - }), + ...(params.computeParametersCount + ? { + parameterCount: computeNumOfParamsByDtypeSharded(headers), + } + : undefined), }; } else { throw new Error("model id does not seem to contain safetensors weights"); From cbdcdeafc39f632e5c2638867857e86af7ca5935 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 18:13:46 +0200 Subject: [PATCH 04/11] Add a few dtypes --- .../hub/src/lib/parse-safetensors-metadata.ts | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index f9bdc8c3e4..d1cef7342d 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -42,7 +42,21 @@ class SafetensorParseError extends Error {} type FileName = string; export type TensorName = string; -export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL"; +export type Dtype = + | "F64" + | "F32" + | "F16" + | "F8_E4M3" + | "F8_E5M2" + | "BF16" + | "I64" + | "I32" + | "I16" + | "I8" + | "U16" + | "U8" + | "BOOL" + | "UNK"; /// when the total_parameters is stored directly in the header, we use this dummy dtype export interface TensorInfo { dtype: Dtype; From 6e02e357ac7291564387688fb5767a56e7f61837 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 18:14:31 +0200 Subject: [PATCH 05/11] implem --- .../hub/src/lib/parse-safetensors-metadata.ts | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index d1cef7342d..dda9e78f68 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -65,13 +65,13 @@ export interface TensorInfo { } export type SafetensorsFileHeader = Record & { - __metadata__: Record; + __metadata__: { total_parameters?: string | number } & Record; }; export interface SafetensorsIndexJson { dtype?: string; /// ^there's sometimes a dtype but it looks inconsistent. - metadata?: Record; + metadata?: { total_parameters?: string | number } & Record; /// ^ why the naming inconsistency? weight_map: Record; } @@ -256,13 +256,14 @@ export async function parseSafetensorsMetadata( (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE })) ) { const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params); + return { sharded: true, index, headers, ...(params.computeParametersCount ? { - parameterCount: computeNumOfParamsByDtypeSharded(headers), + parameterCount: computeNumOfParamsByDtypeSharded(index, headers), } : undefined), }; @@ -272,6 +273,10 @@ export async function parseSafetensorsMetadata( } function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial> { + if (header.__metadata__.total_parameters) { + /// shortcut: get param count directly from metadata + return { UNK: parseInt(header.__metadata__.total_parameters.toString()) }; + } const counter: Partial> = {}; const tensors = omit(header, "__metadata__"); @@ -284,7 +289,14 @@ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Par return counter; } -function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial> { +function computeNumOfParamsByDtypeSharded( + index: SafetensorsIndexJson, + shardedMap: SafetensorsShardedHeaders +): Partial> { + if (index.metadata?.total_parameters) { + /// shortcut: get param count directly from metadata + return { UNK: parseInt(index.metadata.total_parameters.toString()) }; + } const counter: Partial> = {}; for (const header of Object.values(shardedMap)) { for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) { From 19b4ac238c8ca3491b0b596307f8065f1aca43ba Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 18:24:26 +0200 Subject: [PATCH 06/11] implement the tests --- .../lib/parse-safetensors-metadata.spec.ts | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index ca7fa234a1..318162dac9 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -88,7 +88,7 @@ describe("parseSafetensorsMetadata", () => { assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964); }); - it("fetch info for sharded (with the default conventional filename) with file path", async () => { + it("fetch info for sharded with file path", async () => { const parse = await parseSafetensorsMetadata({ repo: "Alignment-Lab-AI/ALAI-gemma-7b", computeParametersCount: true, @@ -110,6 +110,30 @@ describe("parseSafetensorsMetadata", () => { assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896); }); + it("fetch info for sharded, but get param count directly from metadata", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "hf-internal-testing/sharded-model-metadata-num-parameters", + computeParametersCount: true, + revision: "999395eb3db277f3d7a0393402b02486ca91cef8", + }); + + assert(parse.sharded); + assert.deepStrictEqual(parse.parameterCount, { UNK: 109_482_240 }); + // total params = 109M + }); + + it.skip("fetch info for single-file, but get param count directly from metadata", async () => { + /// we don't have an example for this on the Hub yet... cc @LysandreJik + const parse = await parseSafetensorsMetadata({ + repo: "hf-internal-testing/non-sharded-model", + computeParametersCount: true, + revision: "ce6373360e61e6f70b4a1e0cfcc9407b008dea5b", + }); + + assert(!parse.sharded); + assert.deepStrictEqual(parse.parameterCount, { UNK: 666 }); + }); + it("should detect sharded safetensors filename", async () => { const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename); From 68dd4c69b0e20dcc99d35f87b9858fa0657af2b6 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 18:47:51 +0200 Subject: [PATCH 07/11] shortcut like crazy --- .../hub/src/lib/parse-safetensors-metadata.ts | 63 ++++++++++++++----- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index dda9e78f68..a4eb7f725c 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -141,7 +141,7 @@ async function parseShardedIndex( */ fetch?: typeof fetch; } & Partial -): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> { +): Promise { const indexBlob = await downloadFile({ ...params, path, @@ -151,14 +151,28 @@ async function parseShardedIndex( throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`); } - // no validation for now, we assume it's a valid IndexJson. - let index: SafetensorsIndexJson; try { - index = JSON.parse(await indexBlob.slice(0, 10_000_000).text()); + // no validation for now, we assume it's a valid IndexJson. + const index = JSON.parse(await indexBlob.slice(0, 10_000_000).text()); + return index; } catch (error) { throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`); } +} +async function fetchAllHeaders( + path: string, + index: SafetensorsIndexJson, + params: { + repo: RepoDesignation; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1); const filenames = [...new Set(Object.values(index.weight_map))]; const shardedMap: SafetensorsShardedHeaders = Object.fromEntries( @@ -170,7 +184,7 @@ async function parseShardedIndex( PARALLEL_DOWNLOADS ) ); - return { index, headers: shardedMap }; + return shardedMap; } /** @@ -191,6 +205,7 @@ export async function parseSafetensorsMetadata( * @default false */ computeParametersCount: true; + fetchAllHeaders?: boolean; hubUrl?: string; revision?: string; /** @@ -210,6 +225,12 @@ export async function parseSafetensorsMetadata( * @default false */ computeParametersCount?: boolean; + /** + * Always fetch all headers (no shortcut) + * + * @default false + */ + fetchAllHeaders?: boolean; hubUrl?: string; revision?: string; /** @@ -223,6 +244,7 @@ export async function parseSafetensorsMetadata( repo: RepoDesignation; path?: string; computeParametersCount?: boolean; + fetchAllHeaders?: boolean; hubUrl?: string; revision?: string; /** @@ -255,15 +277,31 @@ export async function parseSafetensorsMetadata( (params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) || (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE })) ) { - const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params); + const path = params.path ?? SAFETENSORS_INDEX_FILE; + const index = await parseShardedIndex(path, params); + + const shardedMap = + params.fetchAllHeaders || (params.computeParametersCount && !index.metadata?.total_parameters) + ? await fetchAllHeaders(path, index, params) + : {}; + + if (params.computeParametersCount && index.metadata?.total_parameters) { + /// shortcut: get param count directly from metadata + return { + sharded: true, + index, + headers: shardedMap, + parameterCount: { UNK: parseInt(index.metadata.total_parameters.toString()) }, + }; + } return { sharded: true, index, - headers, + headers: shardedMap, ...(params.computeParametersCount ? { - parameterCount: computeNumOfParamsByDtypeSharded(index, headers), + parameterCount: computeNumOfParamsByDtypeSharded(shardedMap), } : undefined), }; @@ -289,14 +327,7 @@ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Par return counter; } -function computeNumOfParamsByDtypeSharded( - index: SafetensorsIndexJson, - shardedMap: SafetensorsShardedHeaders -): Partial> { - if (index.metadata?.total_parameters) { - /// shortcut: get param count directly from metadata - return { UNK: parseInt(index.metadata.total_parameters.toString()) }; - } +function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial> { const counter: Partial> = {}; for (const header of Object.values(shardedMap)) { for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) { From b3c059dee9c7cb206908a1ab36ae920191d37f39 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 19:15:09 +0200 Subject: [PATCH 08/11] nvm, let's do it much simpler --- .../lib/parse-safetensors-metadata.spec.ts | 6 +-- .../hub/src/lib/parse-safetensors-metadata.ts | 41 ++++++------------- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index 318162dac9..bfdefb5c4f 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -110,7 +110,7 @@ describe("parseSafetensorsMetadata", () => { assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896); }); - it("fetch info for sharded, but get param count directly from metadata", async () => { + it.only("fetch info for sharded, but get param count directly from metadata", async () => { const parse = await parseSafetensorsMetadata({ repo: "hf-internal-testing/sharded-model-metadata-num-parameters", computeParametersCount: true, @@ -118,7 +118,7 @@ describe("parseSafetensorsMetadata", () => { }); assert(parse.sharded); - assert.deepStrictEqual(parse.parameterCount, { UNK: 109_482_240 }); + assert.deepStrictEqual(parse.parameterTotal, 109_482_240); // total params = 109M }); @@ -131,7 +131,7 @@ describe("parseSafetensorsMetadata", () => { }); assert(!parse.sharded); - assert.deepStrictEqual(parse.parameterCount, { UNK: 666 }); + assert.deepStrictEqual(parse.parameterTotal, 666); }); it("should detect sharded safetensors filename", async () => { diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index a4eb7f725c..e593588933 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -55,8 +55,7 @@ export type Dtype = | "I8" | "U16" | "U8" - | "BOOL" - | "UNK"; /// when the total_parameters is stored directly in the header, we use this dummy dtype + | "BOOL"; export interface TensorInfo { dtype: Dtype; @@ -83,12 +82,14 @@ export type SafetensorsParseFromRepo = sharded: false; header: SafetensorsFileHeader; parameterCount?: Partial>; + parameterTotal?: number; } | { sharded: true; index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders; parameterCount?: Partial>; + parameterTotal?: number; }; async function parseSingleFile( @@ -205,7 +206,6 @@ export async function parseSafetensorsMetadata( * @default false */ computeParametersCount: true; - fetchAllHeaders?: boolean; hubUrl?: string; revision?: string; /** @@ -225,12 +225,6 @@ export async function parseSafetensorsMetadata( * @default false */ computeParametersCount?: boolean; - /** - * Always fetch all headers (no shortcut) - * - * @default false - */ - fetchAllHeaders?: boolean; hubUrl?: string; revision?: string; /** @@ -244,7 +238,6 @@ export async function parseSafetensorsMetadata( repo: RepoDesignation; path?: string; computeParametersCount?: boolean; - fetchAllHeaders?: boolean; hubUrl?: string; revision?: string; /** @@ -270,6 +263,11 @@ export async function parseSafetensorsMetadata( ...(params.computeParametersCount ? { parameterCount: computeNumOfParamsByDtypeSingleFile(header), + parameterTotal: + /// shortcut: get param count directly from metadata + header.__metadata__.total_parameters + ? parseInt(header.__metadata__.total_parameters.toString()) + : undefined, } : undefined), }; @@ -279,21 +277,7 @@ export async function parseSafetensorsMetadata( ) { const path = params.path ?? SAFETENSORS_INDEX_FILE; const index = await parseShardedIndex(path, params); - - const shardedMap = - params.fetchAllHeaders || (params.computeParametersCount && !index.metadata?.total_parameters) - ? await fetchAllHeaders(path, index, params) - : {}; - - if (params.computeParametersCount && index.metadata?.total_parameters) { - /// shortcut: get param count directly from metadata - return { - sharded: true, - index, - headers: shardedMap, - parameterCount: { UNK: parseInt(index.metadata.total_parameters.toString()) }, - }; - } + const shardedMap = await fetchAllHeaders(path, index, params); return { sharded: true, @@ -302,6 +286,9 @@ export async function parseSafetensorsMetadata( ...(params.computeParametersCount ? { parameterCount: computeNumOfParamsByDtypeSharded(shardedMap), + parameterTotal: + /// shortcut: get param count directly from metadata + index.metadata?.total_parameters ? parseInt(index.metadata.total_parameters.toString()) : undefined, } : undefined), }; @@ -311,10 +298,6 @@ export async function parseSafetensorsMetadata( } function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial> { - if (header.__metadata__.total_parameters) { - /// shortcut: get param count directly from metadata - return { UNK: parseInt(header.__metadata__.total_parameters.toString()) }; - } const counter: Partial> = {}; const tensors = omit(header, "__metadata__"); From 97de8bca2e5811037b442ddb2b4f3f89b1f64930 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Jun 2025 19:22:50 +0200 Subject: [PATCH 09/11] Update parse-safetensors-metadata.spec.ts --- packages/hub/src/lib/parse-safetensors-metadata.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index bfdefb5c4f..9bbf9b6a73 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -110,7 +110,7 @@ describe("parseSafetensorsMetadata", () => { assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896); }); - it.only("fetch info for sharded, but get param count directly from metadata", async () => { + it("fetch info for sharded, but get param count directly from metadata", async () => { const parse = await parseSafetensorsMetadata({ repo: "hf-internal-testing/sharded-model-metadata-num-parameters", computeParametersCount: true, From c0fc9f3e3803db52ffcf31013eba5e791b054728 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 3 Jun 2025 09:57:29 +0200 Subject: [PATCH 10/11] Harden parameterTotal parsing a bit --- packages/hub/src/lib/parse-safetensors-metadata.ts | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index e593588933..f67f7fffc5 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -266,7 +266,11 @@ export async function parseSafetensorsMetadata( parameterTotal: /// shortcut: get param count directly from metadata header.__metadata__.total_parameters - ? parseInt(header.__metadata__.total_parameters.toString()) + ? typeof header.__metadata__.total_parameters === "number" + ? header.__metadata__.total_parameters + : typeof header.__metadata__.total_parameters === "string" + ? parseInt(header.__metadata__.total_parameters) + : undefined : undefined, } : undefined), @@ -288,7 +292,13 @@ export async function parseSafetensorsMetadata( parameterCount: computeNumOfParamsByDtypeSharded(shardedMap), parameterTotal: /// shortcut: get param count directly from metadata - index.metadata?.total_parameters ? parseInt(index.metadata.total_parameters.toString()) : undefined, + index.metadata?.total_parameters + ? typeof index.metadata.total_parameters === "number" + ? index.metadata.total_parameters + : typeof index.metadata.total_parameters === "string" + ? parseInt(index.metadata.total_parameters) + : undefined + : undefined, } : undefined), }; From e1177fc2657be967dd1f8b1c9045c7fdbab16caf Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 3 Jun 2025 14:05:40 +0200 Subject: [PATCH 11/11] @LysandreJik added an actual model --- packages/hub/src/lib/parse-safetensors-metadata.spec.ts | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index 9bbf9b6a73..57f1fc94bd 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -122,16 +122,15 @@ describe("parseSafetensorsMetadata", () => { // total params = 109M }); - it.skip("fetch info for single-file, but get param count directly from metadata", async () => { - /// we don't have an example for this on the Hub yet... cc @LysandreJik + it("fetch info for single-file, but get param count directly from metadata", async () => { const parse = await parseSafetensorsMetadata({ - repo: "hf-internal-testing/non-sharded-model", + repo: "hf-internal-testing/single-file-model", computeParametersCount: true, - revision: "ce6373360e61e6f70b4a1e0cfcc9407b008dea5b", + revision: "75fcd3fed0285ac7f1092897ff2aefdf24bf872e", }); assert(!parse.sharded); - assert.deepStrictEqual(parse.parameterTotal, 666); + assert.deepStrictEqual(parse.parameterTotal, 109_482_240); }); it("should detect sharded safetensors filename", async () => {