@@ -48,13 +48,19 @@ export type Dtype =
4848 | "F16"
4949 | "F8_E4M3"
5050 | "F8_E5M2"
51+ | "E8M0"
52+ | "F6_E3M2"
53+ | "F6_E2M3"
54+ | "F4"
55+ | "FP4"
5156 | "BF16"
5257 | "I64"
5358 | "I32"
5459 | "I16"
5560 | "I8"
5661 | "U16"
5762 | "U8"
63+ | "UE8"
5864 | "BOOL" ;
5965
6066export interface TensorInfo {
@@ -92,6 +98,35 @@ export type SafetensorsParseFromRepo =
9298 parameterTotal ?: number ;
9399 } ;
94100
101+ /**
102+ * Fetches and parses model config.json
103+ */
104+ async function fetchModelConfig (
105+ params : {
106+ repo : RepoDesignation ;
107+ revision ?: string ;
108+ hubUrl ?: string ;
109+ fetch ?: typeof fetch ;
110+ } & Partial < CredentialsParams >
111+ ) : Promise < ModelConfig | null > {
112+ try {
113+ const configBlob = await downloadFile ( {
114+ ...params ,
115+ path : "config.json" ,
116+ } ) ;
117+
118+ if ( ! configBlob ) {
119+ return null ;
120+ }
121+
122+ const config = JSON . parse ( await configBlob . text ( ) ) ;
123+ return config as ModelConfig ;
124+ } catch ( error ) {
125+ // Config file might not exist or be inaccessible
126+ return null ;
127+ }
128+ }
129+
95130async function parseSingleFile (
96131 path : string ,
97132 params : {
@@ -252,6 +287,10 @@ export async function parseSafetensorsMetadata(
252287 throw new TypeError ( "Only model repos should contain safetensors files." ) ;
253288 }
254289
290+ // Fetch model config for quantization information
291+ const modelConfig = params . computeParametersCount ? await fetchModelConfig ( params ) : null ;
292+ const quantConfig = modelConfig ?. quantization_config ;
293+
255294 if (
256295 ( params . path && RE_SAFETENSORS_FILE . test ( params . path ) ) ||
257296 ( await fileExists ( { ...params , path : SAFETENSORS_FILE } ) )
@@ -262,7 +301,7 @@ export async function parseSafetensorsMetadata(
262301 header,
263302 ...( params . computeParametersCount
264303 ? {
265- parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) ,
304+ parameterCount : computeNumOfParamsByDtypeSingleFile ( header , quantConfig ) ,
266305 parameterTotal :
267306 /// shortcut: get param count directly from metadata
268307 header . __metadata__ . total_parameters
@@ -289,7 +328,7 @@ export async function parseSafetensorsMetadata(
289328 headers : shardedMap ,
290329 ...( params . computeParametersCount
291330 ? {
292- parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap ) ,
331+ parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap , quantConfig ) ,
293332 parameterTotal :
294333 /// shortcut: get param count directly from metadata
295334 index . metadata ?. total_parameters
@@ -307,23 +346,108 @@ export async function parseSafetensorsMetadata(
307346 }
308347}
309348
310- function computeNumOfParamsByDtypeSingleFile ( header : SafetensorsFileHeader ) : Partial < Record < Dtype , number > > {
349+ export interface QuantizationConfig {
350+ quant_method ?: string ;
351+ modules_to_not_convert ?: string [ ] ;
352+ bits ?: number ;
353+ load_in_4bit ?: boolean ;
354+ load_in_8bit ?: boolean ;
355+ }
356+
357+ export interface ModelConfig {
358+ quantization_config ?: QuantizationConfig ;
359+ }
360+
361+ /**
362+ * Determines if a tensor is quantized based on quantization config and tensor name
363+ */
364+ function isQuantizedTensor ( tensorName : string , quantConfig ?: QuantizationConfig ) : boolean {
365+ if ( ! quantConfig || ! quantConfig . modules_to_not_convert ) {
366+ return false ;
367+ }
368+
369+ for ( const pattern of quantConfig . modules_to_not_convert ) {
370+ const regexPattern = pattern . replace ( / \* / g, ".*" ) ;
371+ const regex = new RegExp ( regexPattern ) ;
372+ if ( regex . test ( tensorName ) ) {
373+ return false ;
374+ }
375+ }
376+
377+ return true ;
378+ }
379+
380+ /**
381+ * Gets the parameter multiplier for a quantized tensor based on quantization method
382+ */
383+ function getQuantizationMultiplier ( tensorName : string , dtype : Dtype , quantConfig ?: QuantizationConfig ) : number {
384+ if ( ! quantConfig || ! isQuantizedTensor ( tensorName , quantConfig ) ) {
385+ return 1 ;
386+ }
387+
388+ switch ( quantConfig . quant_method ) {
389+ case "mxfp4" :
390+ if ( dtype === "U8" && tensorName . includes ( "_blocks" ) ) {
391+ return 2 ;
392+ }
393+ return 1 ;
394+
395+ case "gptq" :
396+ case "awq" :
397+ if ( quantConfig . bits === 4 && dtype === "U8" ) {
398+ return 2 ;
399+ }
400+ if ( quantConfig . bits === 2 && dtype === "U8" ) {
401+ return 4 ;
402+ }
403+ return 1 ;
404+
405+ case "bitsandbytes" :
406+ if ( quantConfig . load_in_4bit && dtype === "U8" ) {
407+ return 2 ;
408+ }
409+ if ( quantConfig . load_in_8bit ) {
410+ return 1 ;
411+ }
412+ return 1 ;
413+
414+ default :
415+ if ( quantConfig . load_in_4bit && dtype === "U8" ) {
416+ return 2 ;
417+ }
418+ if ( quantConfig . bits === 4 && dtype === "U8" ) {
419+ return 2 ;
420+ }
421+ return 1 ;
422+ }
423+ }
424+
425+ function computeNumOfParamsByDtypeSingleFile (
426+ header : SafetensorsFileHeader ,
427+ quantConfig ?: QuantizationConfig
428+ ) : Partial < Record < Dtype , number > > {
311429 const counter : Partial < Record < Dtype , number > > = { } ;
312430 const tensors = omit ( header , "__metadata__" ) ;
313431
314- for ( const [ , v ] of typedEntries ( tensors ) ) {
432+ for ( const [ tensorName , v ] of typedEntries ( tensors ) ) {
315433 if ( v . shape . length === 0 ) {
316434 continue ;
317435 }
318- counter [ v . dtype ] = ( counter [ v . dtype ] ?? 0 ) + v . shape . reduce ( ( a , b ) => a * b ) ;
436+
437+ const elements = v . shape . reduce ( ( a , b ) => a * b ) ;
438+ const multiplier = quantConfig ? getQuantizationMultiplier ( tensorName , v . dtype , quantConfig ) : 1 ;
439+ counter [ v . dtype ] = ( counter [ v . dtype ] ?? 0 ) + elements * multiplier ;
319440 }
320441 return counter ;
321442}
322443
323- function computeNumOfParamsByDtypeSharded ( shardedMap : SafetensorsShardedHeaders ) : Partial < Record < Dtype , number > > {
444+ function computeNumOfParamsByDtypeSharded (
445+ shardedMap : SafetensorsShardedHeaders ,
446+ quantConfig ?: QuantizationConfig
447+ ) : Partial < Record < Dtype , number > > {
324448 const counter : Partial < Record < Dtype , number > > = { } ;
325449 for ( const header of Object . values ( shardedMap ) ) {
326- for ( const [ k , v ] of typedEntries ( computeNumOfParamsByDtypeSingleFile ( header ) ) ) {
450+ for ( const [ k , v ] of typedEntries ( computeNumOfParamsByDtypeSingleFile ( header , quantConfig ) ) ) {
327451 counter [ k ] = ( counter [ k ] ?? 0 ) + ( v ?? 0 ) ;
328452 }
329453 }
0 commit comments