@@ -42,7 +42,20 @@ class SafetensorParseError extends Error {}
4242type  FileName  =  string ; 
4343
4444export  type  TensorName  =  string ; 
45- export  type  Dtype  =  "F64"  |  "F32"  |  "F16"  |  "BF16"  |  "I64"  |  "I32"  |  "I16"  |  "I8"  |  "U8"  |  "BOOL" ; 
45+ export  type  Dtype  = 
46+ 	|  "F64" 
47+ 	|  "F32" 
48+ 	|  "F16" 
49+ 	|  "F8_E4M3" 
50+ 	|  "F8_E5M2" 
51+ 	|  "BF16" 
52+ 	|  "I64" 
53+ 	|  "I32" 
54+ 	|  "I16" 
55+ 	|  "I8" 
56+ 	|  "U16" 
57+ 	|  "U8" 
58+ 	|  "BOOL" ; 
4659
4760export  interface  TensorInfo  { 
4861	dtype : Dtype ; 
@@ -51,13 +64,13 @@ export interface TensorInfo {
5164} 
5265
5366export  type  SafetensorsFileHeader  =  Record < TensorName ,  TensorInfo >  &  { 
54- 	__metadata__ : Record < string ,  string > ; 
67+ 	__metadata__ : {   total_parameters ?:  string   |   number   }   &   Record < string ,  string > ; 
5568} ; 
5669
5770export  interface  SafetensorsIndexJson  { 
5871	dtype ?: string ; 
5972	/// ^there's sometimes a dtype but it looks inconsistent. 
60- 	metadata ?: Record < string ,  string > ; 
73+ 	metadata ?: {   total_parameters ?:  string   |   number   }   &   Record < string ,  string > ; 
6174	/// ^ why the naming inconsistency? 
6275	weight_map : Record < TensorName ,  FileName > ; 
6376} 
@@ -69,12 +82,14 @@ export type SafetensorsParseFromRepo =
6982			sharded : false ; 
7083			header : SafetensorsFileHeader ; 
7184			parameterCount ?: Partial < Record < Dtype ,  number > > ; 
85+ 			parameterTotal ?: number ; 
7286	  } 
7387	|  { 
7488			sharded : true ; 
7589			index : SafetensorsIndexJson ; 
7690			headers : SafetensorsShardedHeaders ; 
7791			parameterCount ?: Partial < Record < Dtype ,  number > > ; 
92+ 			parameterTotal ?: number ; 
7893	  } ; 
7994
8095async  function  parseSingleFile ( 
@@ -127,7 +142,7 @@ async function parseShardedIndex(
127142		 */ 
128143		fetch ?: typeof  fetch ; 
129144	}  &  Partial < CredentialsParams > 
130- ) : Promise < {   index :  SafetensorsIndexJson ;   headers :  SafetensorsShardedHeaders   } >  { 
145+ ) : Promise < SafetensorsIndexJson >  { 
131146	const  indexBlob  =  await  downloadFile ( { 
132147		...params , 
133148		path, 
@@ -137,14 +152,28 @@ async function parseShardedIndex(
137152		throw  new  SafetensorParseError ( `Failed to parse file ${ path }  : failed to fetch safetensors index.` ) ; 
138153	} 
139154
140- 	// no validation for now, we assume it's a valid IndexJson. 
141- 	let  index : SafetensorsIndexJson ; 
142155	try  { 
143- 		index  =  JSON . parse ( await  indexBlob . slice ( 0 ,  10_000_000 ) . text ( ) ) ; 
156+ 		// no validation for now, we assume it's a valid IndexJson. 
157+ 		const  index  =  JSON . parse ( await  indexBlob . slice ( 0 ,  10_000_000 ) . text ( ) ) ; 
158+ 		return  index ; 
144159	}  catch  ( error )  { 
145160		throw  new  SafetensorParseError ( `Failed to parse file ${ path }  : not a valid JSON.` ) ; 
146161	} 
162+ } 
147163
164+ async  function  fetchAllHeaders ( 
165+ 	path : string , 
166+ 	index : SafetensorsIndexJson , 
167+ 	params : { 
168+ 		repo : RepoDesignation ; 
169+ 		revision ?: string ; 
170+ 		hubUrl ?: string ; 
171+ 		/** 
172+ 		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. 
173+ 		 */ 
174+ 		fetch ?: typeof  fetch ; 
175+ 	}  &  Partial < CredentialsParams > 
176+ ) : Promise < SafetensorsShardedHeaders >  { 
148177	const  pathPrefix  =  path . slice ( 0 ,  path . lastIndexOf ( "/" )  +  1 ) ; 
149178	const  filenames  =  [ ...new  Set ( Object . values ( index . weight_map ) ) ] ; 
150179	const  shardedMap : SafetensorsShardedHeaders  =  Object . fromEntries ( 
@@ -156,7 +185,7 @@ async function parseShardedIndex(
156185			PARALLEL_DOWNLOADS 
157186		) 
158187	) ; 
159- 	return  {  index ,   headers :  shardedMap   } ; 
188+ 	return  shardedMap ; 
160189} 
161190
162191/** 
@@ -189,12 +218,12 @@ export async function parseSafetensorsMetadata(
189218	params : { 
190219		/** Only models are supported */ 
191220		repo : RepoDesignation ; 
221+ 		path ?: string ; 
192222		/** 
193223		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType 
194224		 * 
195225		 * @default  false 
196226		 */ 
197- 		path ?: string ; 
198227		computeParametersCount ?: boolean ; 
199228		hubUrl ?: string ; 
200229		revision ?: string ; 
@@ -223,27 +252,55 @@ export async function parseSafetensorsMetadata(
223252		throw  new  TypeError ( "Only model repos should contain safetensors files." ) ; 
224253	} 
225254
226- 	if  ( RE_SAFETENSORS_FILE . test ( params . path  ??  "" )  ||  ( await  fileExists ( {  ...params ,  path : SAFETENSORS_FILE  } ) ) )  { 
255+ 	if  ( 
256+ 		( params . path  &&  RE_SAFETENSORS_FILE . test ( params . path ) )  || 
257+ 		( await  fileExists ( {  ...params ,  path : SAFETENSORS_FILE  } ) ) 
258+ 	)  { 
227259		const  header  =  await  parseSingleFile ( params . path  ??  SAFETENSORS_FILE ,  params ) ; 
228260		return  { 
229261			sharded : false , 
230262			header, 
231- 			...( params . computeParametersCount  &&  { 
232- 				parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) , 
233- 			} ) , 
263+ 			...( params . computeParametersCount 
264+ 				? { 
265+ 						parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) , 
266+ 						parameterTotal :
267+ 							/// shortcut: get param count directly from metadata 
268+ 							header . __metadata__ . total_parameters 
269+ 								? typeof  header . __metadata__ . total_parameters  ===  "number" 
270+ 									? header . __metadata__ . total_parameters 
271+ 									: typeof  header . __metadata__ . total_parameters  ===  "string" 
272+ 									  ? parseInt ( header . __metadata__ . total_parameters ) 
273+ 									  : undefined 
274+ 								: undefined , 
275+ 				  } 
276+ 				: undefined ) , 
234277		} ; 
235278	}  else  if  ( 
236- 		RE_SAFETENSORS_INDEX_FILE . test ( params . path   ??   "" )  || 
279+ 		( params . path   &&   RE_SAFETENSORS_INDEX_FILE . test ( params . path ) )  || 
237280		( await  fileExists ( {  ...params ,  path : SAFETENSORS_INDEX_FILE  } ) ) 
238281	)  { 
239- 		const  {  index,  headers }  =  await  parseShardedIndex ( params . path  ??  SAFETENSORS_INDEX_FILE ,  params ) ; 
282+ 		const  path  =  params . path  ??  SAFETENSORS_INDEX_FILE ; 
283+ 		const  index  =  await  parseShardedIndex ( path ,  params ) ; 
284+ 		const  shardedMap  =  await  fetchAllHeaders ( path ,  index ,  params ) ; 
285+ 
240286		return  { 
241287			sharded : true , 
242288			index, 
243- 			headers, 
244- 			...( params . computeParametersCount  &&  { 
245- 				parameterCount : computeNumOfParamsByDtypeSharded ( headers ) , 
246- 			} ) , 
289+ 			headers : shardedMap , 
290+ 			...( params . computeParametersCount 
291+ 				? { 
292+ 						parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap ) , 
293+ 						parameterTotal :
294+ 							/// shortcut: get param count directly from metadata 
295+ 							index . metadata ?. total_parameters 
296+ 								? typeof  index . metadata . total_parameters  ===  "number" 
297+ 									? index . metadata . total_parameters 
298+ 									: typeof  index . metadata . total_parameters  ===  "string" 
299+ 									  ? parseInt ( index . metadata . total_parameters ) 
300+ 									  : undefined 
301+ 								: undefined , 
302+ 				  } 
303+ 				: undefined ) , 
247304		} ; 
248305	}  else  { 
249306		throw  new  Error ( "model id does not seem to contain safetensors weights" ) ; 
0 commit comments