@@ -141,7 +141,7 @@ async function parseShardedIndex(
141141 */
142142 fetch ?: typeof fetch ;
143143 } & Partial < CredentialsParams >
144- ) : Promise < { index : SafetensorsIndexJson ; headers : SafetensorsShardedHeaders } > {
144+ ) : Promise < SafetensorsIndexJson > {
145145 const indexBlob = await downloadFile ( {
146146 ...params ,
147147 path,
@@ -151,14 +151,28 @@ async function parseShardedIndex(
151151 throw new SafetensorParseError ( `Failed to parse file ${ path } : failed to fetch safetensors index.` ) ;
152152 }
153153
154- // no validation for now, we assume it's a valid IndexJson.
155- let index : SafetensorsIndexJson ;
156154 try {
157- index = JSON . parse ( await indexBlob . slice ( 0 , 10_000_000 ) . text ( ) ) ;
155+ // no validation for now, we assume it's a valid IndexJson.
156+ const index = JSON . parse ( await indexBlob . slice ( 0 , 10_000_000 ) . text ( ) ) ;
157+ return index ;
158158 } catch ( error ) {
159159 throw new SafetensorParseError ( `Failed to parse file ${ path } : not a valid JSON.` ) ;
160160 }
161+ }
161162
163+ async function fetchAllHeaders (
164+ path : string ,
165+ index : SafetensorsIndexJson ,
166+ params : {
167+ repo : RepoDesignation ;
168+ revision ?: string ;
169+ hubUrl ?: string ;
170+ /**
171+ * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
172+ */
173+ fetch ?: typeof fetch ;
174+ } & Partial < CredentialsParams >
175+ ) : Promise < SafetensorsShardedHeaders > {
162176 const pathPrefix = path . slice ( 0 , path . lastIndexOf ( "/" ) + 1 ) ;
163177 const filenames = [ ...new Set ( Object . values ( index . weight_map ) ) ] ;
164178 const shardedMap : SafetensorsShardedHeaders = Object . fromEntries (
@@ -170,7 +184,7 @@ async function parseShardedIndex(
170184 PARALLEL_DOWNLOADS
171185 )
172186 ) ;
173- return { index , headers : shardedMap } ;
187+ return shardedMap ;
174188}
175189
176190/**
@@ -191,6 +205,7 @@ export async function parseSafetensorsMetadata(
191205 * @default false
192206 */
193207 computeParametersCount : true ;
208+ fetchAllHeaders ?: boolean ;
194209 hubUrl ?: string ;
195210 revision ?: string ;
196211 /**
@@ -210,6 +225,12 @@ export async function parseSafetensorsMetadata(
210225 * @default false
211226 */
212227 computeParametersCount ?: boolean ;
228+ /**
229+ * Always fetch all headers (no shortcut)
230+ *
231+ * @default false
232+ */
233+ fetchAllHeaders ?: boolean ;
213234 hubUrl ?: string ;
214235 revision ?: string ;
215236 /**
@@ -223,6 +244,7 @@ export async function parseSafetensorsMetadata(
223244 repo : RepoDesignation ;
224245 path ?: string ;
225246 computeParametersCount ?: boolean ;
247+ fetchAllHeaders ?: boolean ;
226248 hubUrl ?: string ;
227249 revision ?: string ;
228250 /**
@@ -255,15 +277,31 @@ export async function parseSafetensorsMetadata(
255277 ( params . path && RE_SAFETENSORS_INDEX_FILE . test ( params . path ) ) ||
256278 ( await fileExists ( { ...params , path : SAFETENSORS_INDEX_FILE } ) )
257279 ) {
258- const { index, headers } = await parseShardedIndex ( params . path ?? SAFETENSORS_INDEX_FILE , params ) ;
280+ const path = params . path ?? SAFETENSORS_INDEX_FILE ;
281+ const index = await parseShardedIndex ( path , params ) ;
282+
283+ const shardedMap =
284+ params . fetchAllHeaders || ( params . computeParametersCount && ! index . metadata ?. total_parameters )
285+ ? await fetchAllHeaders ( path , index , params )
286+ : { } ;
287+
288+ if ( params . computeParametersCount && index . metadata ?. total_parameters ) {
289+ /// shortcut: get param count directly from metadata
290+ return {
291+ sharded : true ,
292+ index,
293+ headers : shardedMap ,
294+ parameterCount : { UNK : parseInt ( index . metadata . total_parameters . toString ( ) ) } ,
295+ } ;
296+ }
259297
260298 return {
261299 sharded : true ,
262300 index,
263- headers,
301+ headers : shardedMap ,
264302 ...( params . computeParametersCount
265303 ? {
266- parameterCount : computeNumOfParamsByDtypeSharded ( index , headers ) ,
304+ parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap ) ,
267305 }
268306 : undefined ) ,
269307 } ;
@@ -289,14 +327,7 @@ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Par
289327 return counter ;
290328}
291329
292- function computeNumOfParamsByDtypeSharded (
293- index : SafetensorsIndexJson ,
294- shardedMap : SafetensorsShardedHeaders
295- ) : Partial < Record < Dtype , number > > {
296- if ( index . metadata ?. total_parameters ) {
297- /// shortcut: get param count directly from metadata
298- return { UNK : parseInt ( index . metadata . total_parameters . toString ( ) ) } ;
299- }
330+ function computeNumOfParamsByDtypeSharded ( shardedMap : SafetensorsShardedHeaders ) : Partial < Record < Dtype , number > > {
300331 const counter : Partial < Record < Dtype , number > > = { } ;
301332 for ( const header of Object . values ( shardedMap ) ) {
302333 for ( const [ k , v ] of typedEntries ( computeNumOfParamsByDtypeSingleFile ( header ) ) ) {
0 commit comments