Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/hub/src/lib/file-download-info.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ describe("fileDownloadInfo", () => {
it("should fetch LFS file info", async () => {
const info = await fileDownloadInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5",
Expand All @@ -19,7 +19,7 @@ describe("fileDownloadInfo", () => {
it("should fetch raw LFS pointer info", async () => {
const info = await fileDownloadInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5",
Expand All @@ -34,7 +34,7 @@ describe("fileDownloadInfo", () => {
it("should fetch non-LFS file info", async () => {
const info = await fileDownloadInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tokenizer_config.json",
Expand Down
4 changes: 2 additions & 2 deletions packages/hub/src/lib/file-exists.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ describe("fileExists", () => {
it("should return true for file that exists", async () => {
const info = await fileExists({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5",
Expand All @@ -18,7 +18,7 @@ describe("fileExists", () => {
it("should return false for file that does not exist", async () => {
const info = await fileExists({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5dadazdzazd",
Expand Down
4 changes: 2 additions & 2 deletions packages/hub/src/lib/list-files.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ describe("listFiles", () => {
it("should fetch the list of files from the repo", async () => {
const cursor = listFiles({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
Expand Down Expand Up @@ -67,7 +67,7 @@ describe("listFiles", () => {
it("should fetch the list of files from the repo, including last commit", async () => {
const cursor = listFiles({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
Expand Down
27 changes: 25 additions & 2 deletions packages/hub/src/lib/parse-safetensors-metadata.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { sum } from "../utils/sum";
describe("parseSafetensorsMetadata", () => {
it("fetch info for single-file (with the default conventional filename)", async () => {
const parse = await parseSafetensorsMetadata({
repo: "bert-base-uncased",
repo: "google-bert/bert-base-uncased",
computeParametersCount: true,
revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
});
Expand Down Expand Up @@ -88,7 +88,7 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
});

it("fetch info for sharded (with the default conventional filename) with file path", async () => {
it("fetch info for sharded with file path", async () => {
const parse = await parseSafetensorsMetadata({
repo: "Alignment-Lab-AI/ALAI-gemma-7b",
computeParametersCount: true,
Expand All @@ -110,6 +110,29 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
});

it("fetch info for sharded, but get param count directly from metadata", async () => {
const parse = await parseSafetensorsMetadata({
repo: "hf-internal-testing/sharded-model-metadata-num-parameters",
computeParametersCount: true,
revision: "999395eb3db277f3d7a0393402b02486ca91cef8",
});

assert(parse.sharded);
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
// total params = 109M
});

it("fetch info for single-file, but get param count directly from metadata", async () => {
const parse = await parseSafetensorsMetadata({
repo: "hf-internal-testing/single-file-model",
computeParametersCount: true,
revision: "75fcd3fed0285ac7f1092897ff2aefdf24bf872e",
});

assert(!parse.sharded);
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
});

it("should detect sharded safetensors filename", async () => {
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);
Expand Down
95 changes: 76 additions & 19 deletions packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,20 @@ class SafetensorParseError extends Error {}
type FileName = string;

export type TensorName = string;
export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
export type Dtype =
| "F64"
| "F32"
| "F16"
| "F8_E4M3"
| "F8_E5M2"
| "BF16"
| "I64"
| "I32"
| "I16"
| "I8"
| "U16"
| "U8"
| "BOOL";

export interface TensorInfo {
dtype: Dtype;
Expand All @@ -51,13 +64,13 @@ export interface TensorInfo {
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
__metadata__: Record<string, string>;
__metadata__: { total_parameters?: string | number } & Record<string, string>;
};

export interface SafetensorsIndexJson {
dtype?: string;
/// ^there's sometimes a dtype but it looks inconsistent.
metadata?: Record<string, string>;
metadata?: { total_parameters?: string | number } & Record<string, string>;
/// ^ why the naming inconsistency?
weight_map: Record<TensorName, FileName>;
}
Expand All @@ -69,12 +82,14 @@ export type SafetensorsParseFromRepo =
sharded: false;
header: SafetensorsFileHeader;
parameterCount?: Partial<Record<Dtype, number>>;
parameterTotal?: number;
}
| {
sharded: true;
index: SafetensorsIndexJson;
headers: SafetensorsShardedHeaders;
parameterCount?: Partial<Record<Dtype, number>>;
parameterTotal?: number;
};

async function parseSingleFile(
Expand Down Expand Up @@ -127,7 +142,7 @@ async function parseShardedIndex(
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
): Promise<SafetensorsIndexJson> {
const indexBlob = await downloadFile({
...params,
path,
Expand All @@ -137,14 +152,28 @@ async function parseShardedIndex(
throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
}

// no validation for now, we assume it's a valid IndexJson.
let index: SafetensorsIndexJson;
try {
index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
// no validation for now, we assume it's a valid IndexJson.
const index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
return index;
} catch (error) {
throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
}
}

async function fetchAllHeaders(
path: string,
index: SafetensorsIndexJson,
params: {
repo: RepoDesignation;
revision?: string;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<SafetensorsShardedHeaders> {
const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
const filenames = [...new Set(Object.values(index.weight_map))];
const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
Expand All @@ -156,7 +185,7 @@ async function parseShardedIndex(
PARALLEL_DOWNLOADS
)
);
return { index, headers: shardedMap };
return shardedMap;
}

/**
Expand Down Expand Up @@ -189,12 +218,12 @@ export async function parseSafetensorsMetadata(
params: {
/** Only models are supported */
repo: RepoDesignation;
path?: string;
/**
* Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
*
* @default false
*/
path?: string;
computeParametersCount?: boolean;
hubUrl?: string;
revision?: string;
Expand Down Expand Up @@ -223,27 +252,55 @@ export async function parseSafetensorsMetadata(
throw new TypeError("Only model repos should contain safetensors files.");
}

if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
if (
(params.path && RE_SAFETENSORS_FILE.test(params.path)) ||
(await fileExists({ ...params, path: SAFETENSORS_FILE }))
) {
const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
return {
sharded: false,
header,
...(params.computeParametersCount && {
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
}),
...(params.computeParametersCount
? {
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
parameterTotal:
/// shortcut: get param count directly from metadata
header.__metadata__.total_parameters
? typeof header.__metadata__.total_parameters === "number"
? header.__metadata__.total_parameters
: typeof header.__metadata__.total_parameters === "string"
? parseInt(header.__metadata__.total_parameters)
: undefined
: undefined,
}
: undefined),
};
} else if (
RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
(params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) ||
(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
) {
const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
const path = params.path ?? SAFETENSORS_INDEX_FILE;
const index = await parseShardedIndex(path, params);
const shardedMap = await fetchAllHeaders(path, index, params);

return {
sharded: true,
index,
headers,
...(params.computeParametersCount && {
parameterCount: computeNumOfParamsByDtypeSharded(headers),
}),
headers: shardedMap,
...(params.computeParametersCount
? {
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap),
parameterTotal:
/// shortcut: get param count directly from metadata
index.metadata?.total_parameters
? typeof index.metadata.total_parameters === "number"
? index.metadata.total_parameters
: typeof index.metadata.total_parameters === "string"
? parseInt(index.metadata.total_parameters)
: undefined
: undefined,
}
: undefined),
};
} else {
throw new Error("model id does not seem to contain safetensors weights");
Expand Down
6 changes: 3 additions & 3 deletions packages/hub/src/lib/paths-info.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ describe("pathsInfo", () => {
it("should fetch LFS path info", async () => {
const result: PathInfo[] = await pathsInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
paths: ["tf_model.h5"],
Expand Down Expand Up @@ -35,7 +35,7 @@ describe("pathsInfo", () => {
securityFileStatus: SecurityFileStatus;
})[] = await pathsInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
paths: ["tf_model.h5"],
Expand All @@ -59,7 +59,7 @@ describe("pathsInfo", () => {
it("non-LFS pointer should have lfs undefined", async () => {
const result: PathInfo[] = await pathsInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
paths: ["config.json"],
Expand Down
Loading