Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/hub/src/lib/file-download-info.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ describe("fileDownloadInfo", () => {
it("should fetch LFS file info", async () => {
const info = await fileDownloadInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5",
Expand All @@ -19,7 +19,7 @@ describe("fileDownloadInfo", () => {
it("should fetch raw LFS pointer info", async () => {
const info = await fileDownloadInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5",
Expand All @@ -34,7 +34,7 @@ describe("fileDownloadInfo", () => {
it("should fetch non-LFS file info", async () => {
const info = await fileDownloadInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tokenizer_config.json",
Expand Down
4 changes: 2 additions & 2 deletions packages/hub/src/lib/file-exists.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ describe("fileExists", () => {
it("should return true for file that exists", async () => {
const info = await fileExists({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5",
Expand All @@ -18,7 +18,7 @@ describe("fileExists", () => {
it("should return false for file that does not exist", async () => {
const info = await fileExists({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
path: "tf_model.h5dadazdzazd",
Expand Down
4 changes: 2 additions & 2 deletions packages/hub/src/lib/list-files.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ describe("listFiles", () => {
it("should fetch the list of files from the repo", async () => {
const cursor = listFiles({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
Expand Down Expand Up @@ -67,7 +67,7 @@ describe("listFiles", () => {
it("should fetch the list of files from the repo, including last commit", async () => {
const cursor = listFiles({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
Expand Down
28 changes: 26 additions & 2 deletions packages/hub/src/lib/parse-safetensors-metadata.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { sum } from "../utils/sum";
describe("parseSafetensorsMetadata", () => {
it("fetch info for single-file (with the default conventional filename)", async () => {
const parse = await parseSafetensorsMetadata({
repo: "bert-base-uncased",
repo: "google-bert/bert-base-uncased",
computeParametersCount: true,
revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
});
Expand Down Expand Up @@ -88,7 +88,7 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
});

it("fetch info for sharded (with the default conventional filename) with file path", async () => {
it("fetch info for sharded with file path", async () => {
const parse = await parseSafetensorsMetadata({
repo: "Alignment-Lab-AI/ALAI-gemma-7b",
computeParametersCount: true,
Expand All @@ -110,6 +110,30 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
});

it("fetch info for sharded, but get param count directly from metadata", async () => {
const parse = await parseSafetensorsMetadata({
repo: "hf-internal-testing/sharded-model-metadata-num-parameters",
computeParametersCount: true,
revision: "999395eb3db277f3d7a0393402b02486ca91cef8",
});

assert(parse.sharded);
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
// total params = 109M
});

it.skip("fetch info for single-file, but get param count directly from metadata", async () => {
/// we don't have an example for this on the Hub yet... cc @LysandreJik
const parse = await parseSafetensorsMetadata({
repo: "hf-internal-testing/non-sharded-model",
computeParametersCount: true,
revision: "ce6373360e61e6f70b4a1e0cfcc9407b008dea5b",
});

assert(!parse.sharded);
assert.deepStrictEqual(parse.parameterTotal, 666);
});

it("should detect sharded safetensors filename", async () => {
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);
Expand Down
85 changes: 66 additions & 19 deletions packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,20 @@ class SafetensorParseError extends Error {}
type FileName = string;

export type TensorName = string;
export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
export type Dtype =
| "F64"
| "F32"
| "F16"
| "F8_E4M3"
| "F8_E5M2"
| "BF16"
| "I64"
| "I32"
| "I16"
| "I8"
| "U16"
| "U8"
| "BOOL";

export interface TensorInfo {
dtype: Dtype;
Expand All @@ -51,13 +64,13 @@ export interface TensorInfo {
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
__metadata__: Record<string, string>;
__metadata__: { total_parameters?: string | number } & Record<string, string>;
};

export interface SafetensorsIndexJson {
dtype?: string;
/// ^there's sometimes a dtype but it looks inconsistent.
metadata?: Record<string, string>;
metadata?: { total_parameters?: string | number } & Record<string, string>;
/// ^ why the naming inconsistency?
weight_map: Record<TensorName, FileName>;
}
Expand All @@ -69,12 +82,14 @@ export type SafetensorsParseFromRepo =
sharded: false;
header: SafetensorsFileHeader;
parameterCount?: Partial<Record<Dtype, number>>;
parameterTotal?: number;
}
| {
sharded: true;
index: SafetensorsIndexJson;
headers: SafetensorsShardedHeaders;
parameterCount?: Partial<Record<Dtype, number>>;
parameterTotal?: number;
};

async function parseSingleFile(
Expand Down Expand Up @@ -127,7 +142,7 @@ async function parseShardedIndex(
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
): Promise<SafetensorsIndexJson> {
const indexBlob = await downloadFile({
...params,
path,
Expand All @@ -137,14 +152,28 @@ async function parseShardedIndex(
throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
}

// no validation for now, we assume it's a valid IndexJson.
let index: SafetensorsIndexJson;
try {
index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
// no validation for now, we assume it's a valid IndexJson.
const index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
return index;
} catch (error) {
throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
}
}

async function fetchAllHeaders(
path: string,
index: SafetensorsIndexJson,
params: {
repo: RepoDesignation;
revision?: string;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<SafetensorsShardedHeaders> {
const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
const filenames = [...new Set(Object.values(index.weight_map))];
const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
Expand All @@ -156,7 +185,7 @@ async function parseShardedIndex(
PARALLEL_DOWNLOADS
)
);
return { index, headers: shardedMap };
return shardedMap;
}

/**
Expand Down Expand Up @@ -189,12 +218,12 @@ export async function parseSafetensorsMetadata(
params: {
/** Only models are supported */
repo: RepoDesignation;
path?: string;
/**
* Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
*
* @default false
*/
path?: string;
computeParametersCount?: boolean;
hubUrl?: string;
revision?: string;
Expand Down Expand Up @@ -223,27 +252,45 @@ export async function parseSafetensorsMetadata(
throw new TypeError("Only model repos should contain safetensors files.");
}

if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
if (
(params.path && RE_SAFETENSORS_FILE.test(params.path)) ||
(await fileExists({ ...params, path: SAFETENSORS_FILE }))
) {
const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
return {
sharded: false,
header,
...(params.computeParametersCount && {
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
}),
...(params.computeParametersCount
? {
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
parameterTotal:
/// shortcut: get param count directly from metadata
header.__metadata__.total_parameters
? parseInt(header.__metadata__.total_parameters.toString())
: undefined,
}
: undefined),
};
} else if (
RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
(params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) ||
(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
) {
const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
const path = params.path ?? SAFETENSORS_INDEX_FILE;
const index = await parseShardedIndex(path, params);
const shardedMap = await fetchAllHeaders(path, index, params);

return {
sharded: true,
index,
headers,
...(params.computeParametersCount && {
parameterCount: computeNumOfParamsByDtypeSharded(headers),
}),
headers: shardedMap,
...(params.computeParametersCount
? {
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap),
parameterTotal:
/// shortcut: get param count directly from metadata
index.metadata?.total_parameters ? parseInt(index.metadata.total_parameters.toString()) : undefined,
}
: undefined),
};
} else {
throw new Error("model id does not seem to contain safetensors weights");
Expand Down
6 changes: 3 additions & 3 deletions packages/hub/src/lib/paths-info.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ describe("pathsInfo", () => {
it("should fetch LFS path info", async () => {
const result: PathInfo[] = await pathsInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
paths: ["tf_model.h5"],
Expand Down Expand Up @@ -35,7 +35,7 @@ describe("pathsInfo", () => {
securityFileStatus: SecurityFileStatus;
})[] = await pathsInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
paths: ["tf_model.h5"],
Expand All @@ -59,7 +59,7 @@ describe("pathsInfo", () => {
it("non-LFS pointer should have lfs undefined", async () => {
const result: PathInfo[] = await pathsInfo({
repo: {
name: "bert-base-uncased",
name: "google-bert/bert-base-uncased",
type: "model",
},
paths: ["config.json"],
Expand Down
Loading