@@ -65,13 +65,13 @@ export interface TensorInfo {
6565}
6666
6767export type SafetensorsFileHeader = Record < TensorName , TensorInfo > & {
68- __metadata__ : Record < string , string > ;
68+ __metadata__ : { total_parameters ?: string | number } & Record < string , string > ;
6969} ;
7070
7171export interface SafetensorsIndexJson {
7272 dtype ?: string ;
7373 /// ^there's sometimes a dtype but it looks inconsistent.
74- metadata ?: Record < string , string > ;
74+ metadata ?: { total_parameters ?: string | number } & Record < string , string > ;
7575 /// ^ why the naming inconsistency?
7676 weight_map : Record < TensorName , FileName > ;
7777}
@@ -256,13 +256,14 @@ export async function parseSafetensorsMetadata(
256256 ( await fileExists ( { ...params , path : SAFETENSORS_INDEX_FILE } ) )
257257 ) {
258258 const { index, headers } = await parseShardedIndex ( params . path ?? SAFETENSORS_INDEX_FILE , params ) ;
259+
259260 return {
260261 sharded : true ,
261262 index,
262263 headers,
263264 ...( params . computeParametersCount
264265 ? {
265- parameterCount : computeNumOfParamsByDtypeSharded ( headers ) ,
266+ parameterCount : computeNumOfParamsByDtypeSharded ( index , headers ) ,
266267 }
267268 : undefined ) ,
268269 } ;
@@ -272,6 +273,10 @@ export async function parseSafetensorsMetadata(
272273}
273274
274275function computeNumOfParamsByDtypeSingleFile ( header : SafetensorsFileHeader ) : Partial < Record < Dtype , number > > {
276+ if ( header . __metadata__ . total_parameters ) {
277+ /// shortcut: get param count directly from metadata
278+ return { UNK : parseInt ( header . __metadata__ . total_parameters . toString ( ) ) } ;
279+ }
275280 const counter : Partial < Record < Dtype , number > > = { } ;
276281 const tensors = omit ( header , "__metadata__" ) ;
277282
@@ -284,7 +289,14 @@ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Par
284289 return counter ;
285290}
286291
287- function computeNumOfParamsByDtypeSharded ( shardedMap : SafetensorsShardedHeaders ) : Partial < Record < Dtype , number > > {
292+ function computeNumOfParamsByDtypeSharded (
293+ index : SafetensorsIndexJson ,
294+ shardedMap : SafetensorsShardedHeaders
295+ ) : Partial < Record < Dtype , number > > {
296+ if ( index . metadata ?. total_parameters ) {
297+ /// shortcut: get param count directly from metadata
298+ return { UNK : parseInt ( index . metadata . total_parameters . toString ( ) ) } ;
299+ }
288300 const counter : Partial < Record < Dtype , number > > = { } ;
289301 for ( const header of Object . values ( shardedMap ) ) {
290302 for ( const [ k , v ] of typedEntries ( computeNumOfParamsByDtypeSingleFile ( header ) ) ) {
0 commit comments