Skip to content

Commit 68dd4c6

Browse files
committed
shortcut like crazy
1 parent 19b4ac2 commit 68dd4c6

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ async function parseShardedIndex(
141141
*/
142142
fetch?: typeof fetch;
143143
} & Partial<CredentialsParams>
144-
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
144+
): Promise<SafetensorsIndexJson> {
145145
const indexBlob = await downloadFile({
146146
...params,
147147
path,
@@ -151,14 +151,28 @@ async function parseShardedIndex(
151151
throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
152152
}
153153

154-
// no validation for now, we assume it's a valid IndexJson.
155-
let index: SafetensorsIndexJson;
156154
try {
157-
index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
155+
// no validation for now, we assume it's a valid IndexJson.
156+
const index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
157+
return index;
158158
} catch (error) {
159159
throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
160160
}
161+
}
161162

163+
async function fetchAllHeaders(
164+
path: string,
165+
index: SafetensorsIndexJson,
166+
params: {
167+
repo: RepoDesignation;
168+
revision?: string;
169+
hubUrl?: string;
170+
/**
171+
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
172+
*/
173+
fetch?: typeof fetch;
174+
} & Partial<CredentialsParams>
175+
): Promise<SafetensorsShardedHeaders> {
162176
const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
163177
const filenames = [...new Set(Object.values(index.weight_map))];
164178
const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
@@ -170,7 +184,7 @@ async function parseShardedIndex(
170184
PARALLEL_DOWNLOADS
171185
)
172186
);
173-
return { index, headers: shardedMap };
187+
return shardedMap;
174188
}
175189

176190
/**
@@ -191,6 +205,7 @@ export async function parseSafetensorsMetadata(
191205
* @default false
192206
*/
193207
computeParametersCount: true;
208+
fetchAllHeaders?: boolean;
194209
hubUrl?: string;
195210
revision?: string;
196211
/**
@@ -210,6 +225,12 @@ export async function parseSafetensorsMetadata(
210225
* @default false
211226
*/
212227
computeParametersCount?: boolean;
228+
/**
229+
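* Always fetch the headers of every shard (no shortcut). When false, `headers` is left empty unless the shard headers are needed to compute the parameter count, i.e. `computeParametersCount` is set and the index metadata has no `total_parameters`.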
230+
*
231+
* @default false
232+
*/
233+
fetchAllHeaders?: boolean;
213234
hubUrl?: string;
214235
revision?: string;
215236
/**
@@ -223,6 +244,7 @@ export async function parseSafetensorsMetadata(
223244
repo: RepoDesignation;
224245
path?: string;
225246
computeParametersCount?: boolean;
247+
fetchAllHeaders?: boolean;
226248
hubUrl?: string;
227249
revision?: string;
228250
/**
@@ -255,15 +277,31 @@ export async function parseSafetensorsMetadata(
255277
(params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) ||
256278
(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
257279
) {
258-
const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
280+
const path = params.path ?? SAFETENSORS_INDEX_FILE;
281+
const index = await parseShardedIndex(path, params);
282+
283+
const shardedMap =
284+
params.fetchAllHeaders || (params.computeParametersCount && !index.metadata?.total_parameters)
285+
? await fetchAllHeaders(path, index, params)
286+
: {};
287+
288+
if (params.computeParametersCount && index.metadata?.total_parameters) {
289+
/// shortcut: get param count directly from metadata
290+
return {
291+
sharded: true,
292+
index,
293+
headers: shardedMap,
294+
parameterCount: { UNK: parseInt(index.metadata.total_parameters.toString()) },
295+
};
296+
}
259297

260298
return {
261299
sharded: true,
262300
index,
263-
headers,
301+
headers: shardedMap,
264302
...(params.computeParametersCount
265303
? {
266-
parameterCount: computeNumOfParamsByDtypeSharded(index, headers),
304+
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap),
267305
}
268306
: undefined),
269307
};
@@ -289,14 +327,7 @@ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Par
289327
return counter;
290328
}
291329

292-
function computeNumOfParamsByDtypeSharded(
293-
index: SafetensorsIndexJson,
294-
shardedMap: SafetensorsShardedHeaders
295-
): Partial<Record<Dtype, number>> {
296-
if (index.metadata?.total_parameters) {
297-
/// shortcut: get param count directly from metadata
298-
return { UNK: parseInt(index.metadata.total_parameters.toString()) };
299-
}
330+
function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
300331
const counter: Partial<Record<Dtype, number>> = {};
301332
for (const header of Object.values(shardedMap)) {
302333
for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {

0 commit comments

Comments
 (0)