Skip to content

Commit 19b4ac2

Browse files
committed
implement the tests
1 parent 6e02e35 commit 19b4ac2

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ describe("parseSafetensorsMetadata", () => {
8888
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
8989
});
9090

91-
it("fetch info for sharded (with the default conventional filename) with file path", async () => {
91+
it("fetch info for sharded with file path", async () => {
9292
const parse = await parseSafetensorsMetadata({
9393
repo: "Alignment-Lab-AI/ALAI-gemma-7b",
9494
computeParametersCount: true,
@@ -110,6 +110,30 @@ describe("parseSafetensorsMetadata", () => {
110110
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
111111
});
112112

113+
it("fetch info for sharded, but get param count directly from metadata", async () => {
114+
const parse = await parseSafetensorsMetadata({
115+
repo: "hf-internal-testing/sharded-model-metadata-num-parameters",
116+
computeParametersCount: true,
117+
revision: "999395eb3db277f3d7a0393402b02486ca91cef8",
118+
});
119+
120+
assert(parse.sharded);
121+
assert.deepStrictEqual(parse.parameterCount, { UNK: 109_482_240 });
122+
// total params = 109M
123+
});
124+
125+
it.skip("fetch info for single-file, but get param count directly from metadata", async () => {
126+
/// we don't have an example for this on the Hub yet... cc @LysandreJik
127+
const parse = await parseSafetensorsMetadata({
128+
repo: "hf-internal-testing/non-sharded-model",
129+
computeParametersCount: true,
130+
revision: "ce6373360e61e6f70b4a1e0cfcc9407b008dea5b",
131+
});
132+
133+
assert(!parse.sharded);
134+
assert.deepStrictEqual(parse.parameterCount, { UNK: 666 });
135+
});
136+
113137
it("should detect sharded safetensors filename", async () => {
114138
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
115139
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);

0 commit comments

Comments
 (0)