@@ -55,8 +55,7 @@ export type Dtype =
5555 | "I8"
5656 | "U16"
5757 | "U8"
58- | "BOOL"
59- | "UNK" ; /// when the total_parameters is stored directly in the header, we use this dummy dtype
58+ | "BOOL" ;
6059
6160export interface TensorInfo {
6261 dtype : Dtype ;
@@ -83,12 +82,14 @@ export type SafetensorsParseFromRepo =
8382 sharded : false ;
8483 header : SafetensorsFileHeader ;
8584 parameterCount ?: Partial < Record < Dtype , number > > ;
85+ parameterTotal ?: number ;
8686 }
8787 | {
8888 sharded : true ;
8989 index : SafetensorsIndexJson ;
9090 headers : SafetensorsShardedHeaders ;
9191 parameterCount ?: Partial < Record < Dtype , number > > ;
92+ parameterTotal ?: number ;
9293 } ;
9394
9495async function parseSingleFile (
@@ -205,7 +206,6 @@ export async function parseSafetensorsMetadata(
205206 * @default false
206207 */
207208 computeParametersCount : true ;
208- fetchAllHeaders ?: boolean ;
209209 hubUrl ?: string ;
210210 revision ?: string ;
211211 /**
@@ -225,12 +225,6 @@ export async function parseSafetensorsMetadata(
225225 * @default false
226226 */
227227 computeParametersCount ?: boolean ;
228- /**
229- * Always fetch all headers (no shortcut)
230- *
231- * @default false
232- */
233- fetchAllHeaders ?: boolean ;
234228 hubUrl ?: string ;
235229 revision ?: string ;
236230 /**
@@ -244,7 +238,6 @@ export async function parseSafetensorsMetadata(
244238 repo : RepoDesignation ;
245239 path ?: string ;
246240 computeParametersCount ?: boolean ;
247- fetchAllHeaders ?: boolean ;
248241 hubUrl ?: string ;
249242 revision ?: string ;
250243 /**
@@ -270,6 +263,11 @@ export async function parseSafetensorsMetadata(
270263 ...( params . computeParametersCount
271264 ? {
272265 parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) ,
266+ parameterTotal :
267+ /// shortcut: get param count directly from metadata
268+ header . __metadata__ . total_parameters
269+ ? parseInt ( header . __metadata__ . total_parameters . toString ( ) )
270+ : undefined ,
273271 }
274272 : undefined ) ,
275273 } ;
@@ -279,21 +277,7 @@ export async function parseSafetensorsMetadata(
279277 ) {
280278 const path = params . path ?? SAFETENSORS_INDEX_FILE ;
281279 const index = await parseShardedIndex ( path , params ) ;
282-
283- const shardedMap =
284- params . fetchAllHeaders || ( params . computeParametersCount && ! index . metadata ?. total_parameters )
285- ? await fetchAllHeaders ( path , index , params )
286- : { } ;
287-
288- if ( params . computeParametersCount && index . metadata ?. total_parameters ) {
289- /// shortcut: get param count directly from metadata
290- return {
291- sharded : true ,
292- index,
293- headers : shardedMap ,
294- parameterCount : { UNK : parseInt ( index . metadata . total_parameters . toString ( ) ) } ,
295- } ;
296- }
280+ const shardedMap = await fetchAllHeaders ( path , index , params ) ;
297281
298282 return {
299283 sharded : true ,
@@ -302,6 +286,9 @@ export async function parseSafetensorsMetadata(
302286 ...( params . computeParametersCount
303287 ? {
304288 parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap ) ,
289+ parameterTotal :
290+ /// shortcut: get param count directly from metadata
291+ index . metadata ?. total_parameters ? parseInt ( index . metadata . total_parameters . toString ( ) ) : undefined ,
305292 }
306293 : undefined ) ,
307294 } ;
@@ -311,10 +298,6 @@ export async function parseSafetensorsMetadata(
311298}
312299
313300function computeNumOfParamsByDtypeSingleFile ( header : SafetensorsFileHeader ) : Partial < Record < Dtype , number > > {
314- if ( header . __metadata__ . total_parameters ) {
315- /// shortcut: get param count directly from metadata
316- return { UNK : parseInt ( header . __metadata__ . total_parameters . toString ( ) ) } ;
317- }
318301 const counter : Partial < Record < Dtype , number > > = { } ;
319302 const tensors = omit ( header , "__metadata__" ) ;
320303
0 commit comments