diff --git a/packages/hub/src/lib/file-download-info.spec.ts b/packages/hub/src/lib/file-download-info.spec.ts index d2be156626..0a18cc7573 100644 --- a/packages/hub/src/lib/file-download-info.spec.ts +++ b/packages/hub/src/lib/file-download-info.spec.ts @@ -5,7 +5,7 @@ describe("fileDownloadInfo", () => { it("should fetch LFS file info", async () => { const info = await fileDownloadInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5", @@ -19,7 +19,7 @@ describe("fileDownloadInfo", () => { it("should fetch raw LFS pointer info", async () => { const info = await fileDownloadInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5", @@ -34,7 +34,7 @@ describe("fileDownloadInfo", () => { it("should fetch non-LFS file info", async () => { const info = await fileDownloadInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tokenizer_config.json", diff --git a/packages/hub/src/lib/file-exists.spec.ts b/packages/hub/src/lib/file-exists.spec.ts index e20acdf3bc..54c8ccd90e 100644 --- a/packages/hub/src/lib/file-exists.spec.ts +++ b/packages/hub/src/lib/file-exists.spec.ts @@ -5,7 +5,7 @@ describe("fileExists", () => { it("should return true for file that exists", async () => { const info = await fileExists({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5", @@ -18,7 +18,7 @@ describe("fileExists", () => { it("should return false for file that does not exist", async () => { const info = await fileExists({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, path: "tf_model.h5dadazdzazd", diff --git a/packages/hub/src/lib/list-files.spec.ts b/packages/hub/src/lib/list-files.spec.ts index 7014193075..00d3777de8 100644 --- a/packages/hub/src/lib/list-files.spec.ts +++ b/packages/hub/src/lib/list-files.spec.ts @@ -6,7 +6,7 @@ describe("listFiles", () => { it("should fetch the list of files from the repo", async () => { const cursor = listFiles({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", @@ -67,7 +67,7 @@ describe("listFiles", () => { it("should fetch the list of files from the repo, including last commit", async () => { const cursor = listFiles({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index d96f5ed650..57f1fc94bd 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -5,7 +5,7 @@ import { sum } from "../utils/sum"; describe("parseSafetensorsMetadata", () => { it("fetch info for single-file (with the default conventional filename)", async () => { const parse = await parseSafetensorsMetadata({ - repo: "bert-base-uncased", + repo: "google-bert/bert-base-uncased", computeParametersCount: true, revision: "86b5e0934494bd15c9632b12f734a8a67f723594", }); @@ -88,7 +88,7 @@ describe("parseSafetensorsMetadata", () => { assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964); }); - it("fetch info for sharded (with the default conventional filename) with file path", async () => { + it("fetch info for sharded with file path", async () => { const parse = await parseSafetensorsMetadata({ repo: "Alignment-Lab-AI/ALAI-gemma-7b", computeParametersCount: true, @@ -110,6 +110,29 @@ describe("parseSafetensorsMetadata", () => { assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896); }); + it("fetch info for sharded, but get param count directly from metadata", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "hf-internal-testing/sharded-model-metadata-num-parameters", + computeParametersCount: true, + revision: "999395eb3db277f3d7a0393402b02486ca91cef8", + }); + + assert(parse.sharded); + assert.deepStrictEqual(parse.parameterTotal, 109_482_240); + // total params = 109M + }); + + it("fetch info for single-file, but get param count directly from metadata", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "hf-internal-testing/single-file-model", + computeParametersCount: true, + revision: "75fcd3fed0285ac7f1092897ff2aefdf24bf872e", + }); + + assert(!parse.sharded); + assert.deepStrictEqual(parse.parameterTotal, 109_482_240); + }); + it("should detect sharded safetensors filename", async () => { const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename); diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index ca43a00883..f67f7fffc5 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -42,7 +42,20 @@ class SafetensorParseError extends Error {} type FileName = string; export type TensorName = string; -export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL"; +export type Dtype = + | "F64" + | "F32" + | "F16" + | "F8_E4M3" + | "F8_E5M2" + | "BF16" + | "I64" + | "I32" + | "I16" + | "I8" + | "U16" + | "U8" + | "BOOL"; export interface TensorInfo { dtype: Dtype; @@ -51,13 +64,13 @@ export interface TensorInfo { } export type SafetensorsFileHeader = Record & { - __metadata__: Record; + __metadata__: { total_parameters?: string | number } & Record; }; export interface SafetensorsIndexJson { dtype?: string; /// ^there's sometimes a dtype but it looks inconsistent. - metadata?: Record; + metadata?: { total_parameters?: string | number } & Record; /// ^ why the naming inconsistency? weight_map: Record; } @@ -69,12 +82,14 @@ export type SafetensorsParseFromRepo = sharded: false; header: SafetensorsFileHeader; parameterCount?: Partial>; + parameterTotal?: number; } | { sharded: true; index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders; parameterCount?: Partial>; + parameterTotal?: number; }; async function parseSingleFile( @@ -127,7 +142,7 @@ async function parseShardedIndex( */ fetch?: typeof fetch; } & Partial -): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> { +): Promise { const indexBlob = await downloadFile({ ...params, path, @@ -137,14 +152,28 @@ async function parseShardedIndex( throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`); } - // no validation for now, we assume it's a valid IndexJson. - let index: SafetensorsIndexJson; try { - index = JSON.parse(await indexBlob.slice(0, 10_000_000).text()); + // no validation for now, we assume it's a valid IndexJson. + const index = JSON.parse(await indexBlob.slice(0, 10_000_000).text()); + return index; } catch (error) { throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`); } +} +async function fetchAllHeaders( + path: string, + index: SafetensorsIndexJson, + params: { + repo: RepoDesignation; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1); const filenames = [...new Set(Object.values(index.weight_map))]; const shardedMap: SafetensorsShardedHeaders = Object.fromEntries( @@ -156,7 +185,7 @@ async function parseShardedIndex( PARALLEL_DOWNLOADS ) ); - return { index, headers: shardedMap }; + return shardedMap; } /** @@ -189,12 +218,12 @@ export async function parseSafetensorsMetadata( params: { /** Only models are supported */ repo: RepoDesignation; + path?: string; /** * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType * * @default false */ - path?: string; computeParametersCount?: boolean; hubUrl?: string; revision?: string; @@ -223,27 +252,55 @@ export async function parseSafetensorsMetadata( throw new TypeError("Only model repos should contain safetensors files."); } - if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) { + if ( + (params.path && RE_SAFETENSORS_FILE.test(params.path)) || + (await fileExists({ ...params, path: SAFETENSORS_FILE })) + ) { const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params); return { sharded: false, header, - ...(params.computeParametersCount && { - parameterCount: computeNumOfParamsByDtypeSingleFile(header), - }), + ...(params.computeParametersCount + ? { + parameterCount: computeNumOfParamsByDtypeSingleFile(header), + parameterTotal: + /// shortcut: get param count directly from metadata + header.__metadata__.total_parameters + ? typeof header.__metadata__.total_parameters === "number" + ? header.__metadata__.total_parameters + : typeof header.__metadata__.total_parameters === "string" + ? parseInt(header.__metadata__.total_parameters) + : undefined + : undefined, + } + : undefined), }; } else if ( - RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") || + (params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) || (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE })) ) { - const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params); + const path = params.path ?? SAFETENSORS_INDEX_FILE; + const index = await parseShardedIndex(path, params); + const shardedMap = await fetchAllHeaders(path, index, params); + return { sharded: true, index, - headers, - ...(params.computeParametersCount && { - parameterCount: computeNumOfParamsByDtypeSharded(headers), - }), + headers: shardedMap, + ...(params.computeParametersCount + ? { + parameterCount: computeNumOfParamsByDtypeSharded(shardedMap), + parameterTotal: + /// shortcut: get param count directly from metadata + index.metadata?.total_parameters + ? typeof index.metadata.total_parameters === "number" + ? index.metadata.total_parameters + : typeof index.metadata.total_parameters === "string" + ? parseInt(index.metadata.total_parameters) + : undefined + : undefined, + } + : undefined), }; } else { throw new Error("model id does not seem to contain safetensors weights"); diff --git a/packages/hub/src/lib/paths-info.spec.ts b/packages/hub/src/lib/paths-info.spec.ts index 837f4a1924..84219b66e7 100644 --- a/packages/hub/src/lib/paths-info.spec.ts +++ b/packages/hub/src/lib/paths-info.spec.ts @@ -6,7 +6,7 @@ describe("pathsInfo", () => { it("should fetch LFS path info", async () => { const result: PathInfo[] = await pathsInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, paths: ["tf_model.h5"], @@ -35,7 +35,7 @@ describe("pathsInfo", () => { securityFileStatus: SecurityFileStatus; })[] = await pathsInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, paths: ["tf_model.h5"], @@ -59,7 +59,7 @@ describe("pathsInfo", () => { it("non-LFS pointer should have lfs undefined", async () => { const result: PathInfo[] = await pathsInfo({ repo: { - name: "bert-base-uncased", + name: "google-bert/bert-base-uncased", type: "model", }, paths: ["config.json"],