@@ -169,7 +169,7 @@ def from_pretrained_model(
169
169
cls ,
170
170
model : Module ,
171
171
sparsity_config : Union [SparsityCompressionConfig , str , None ] = None ,
172
- quantization_format : Optional [str ] = None ,
172
+ quantization_format : Optional [Union [ str , List [ str ]] ] = None ,
173
173
) -> Optional ["ModelCompressor" ]:
174
174
"""
175
175
Given a pytorch model and optional sparsity and/or quantization configs,
@@ -182,7 +182,6 @@ def from_pretrained_model(
182
182
algorithm
183
183
:return: compressor for the configs, or None if model is not compressed
184
184
"""
185
- # reconstruct config from schemes attached to modules
186
185
quantization_config = QuantizationConfig .from_pretrained (
187
186
model , format = quantization_format
188
187
)
@@ -203,6 +202,9 @@ def from_pretrained_model(
203
202
sparsity_config = sparsity_config ,
204
203
quantization_config = quantization_config ,
205
204
transform_config = transform_config ,
205
+ compression_formats = [quantization_format ]
206
+ if isinstance (quantization_format , str )
207
+ else quantization_format ,
206
208
)
207
209
208
210
@staticmethod
@@ -263,30 +265,61 @@ def parse_quantization_config(
263
265
264
266
return quantization_config
265
267
268
+ def _fetch_unique_quantization_formats (self ) -> List [str ]:
269
+ """
270
+ Get all unique compression formats present in a model.
271
+ :return: list of quantization formats
272
+ """
273
+ quantization_formats = []
274
+ for _ , scheme in self .quantization_config .config_groups .items ():
275
+ if scheme .format is not None and scheme .format not in quantization_formats :
276
+ quantization_formats .append (scheme .format )
277
+
278
+ if (
279
+ len (quantization_formats ) == 0
280
+ and self .quantization_config .format
281
+ != CompressionFormat .mixed_precision .value
282
+ ):
283
+ quantization_formats .append (self .quantization_config .format )
284
+ return quantization_formats
285
+
266
286
def __init__ (
267
287
self ,
268
288
sparsity_config : Optional [SparsityCompressionConfig ] = None ,
269
289
quantization_config : Optional [QuantizationConfig ] = None ,
270
290
transform_config : Optional [TransformConfig ] = None ,
291
+ compression_formats : Optional [List [str ]] = None ,
271
292
):
272
293
self .sparsity_config = sparsity_config
273
294
self .quantization_config = quantization_config
274
295
self .transform_config = transform_config
296
+ self .compression_formats = compression_formats
275
297
276
298
self .sparsity_compressor = None
277
299
self .quantization_compressor : Optional [
278
- Union [BaseQuantizationCompressor , DenseCompressor ]
300
+ Dict [ str , Union [BaseQuantizationCompressor , DenseCompressor ] ]
279
301
] = None
280
302
# no transform compressor is required
281
303
282
304
if sparsity_config is not None :
283
305
self .sparsity_compressor = BaseCompressor .load_from_registry (
284
306
sparsity_config .format , config = sparsity_config
285
307
)
308
+
286
309
if quantization_config is not None :
287
- self .quantization_compressor = BaseCompressor .load_from_registry (
288
- quantization_config .format , config = quantization_config
289
- )
310
+ # If a list of compression_format is not provided, we resolve the
311
+ # relevant quantization formats using the config groups from the config
312
+ # and if those are not defined, we fall-back to the global quantization format
313
+ if not self .compression_formats :
314
+ self .compression_formats = self ._fetch_unique_quantization_formats ()
315
+
316
+ self .quantization_compressor = {}
317
+ for format in self .compression_formats :
318
+ self .quantization_compressor [
319
+ format
320
+ ] = BaseCompressor .load_from_registry (
321
+ format , config = quantization_config
322
+ )
290
323
291
324
# ----- used by hf quantizer ----- #
292
325
@@ -381,12 +414,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
381
414
targets = scheme .targets ,
382
415
ignore = self .quantization_config .ignore ,
383
416
)
384
- unexpected_keys .update (
385
- merge_names (target , param )
386
- for target in quant_targets
387
- for param in self .quantization_compressor .compression_param_names
388
- if param != "weight"
389
- )
417
+ for quant_compressor in self .quantization_compressor .values ():
418
+ unexpected_keys .update (
419
+ merge_names (target , param )
420
+ for target in quant_targets
421
+ for param in quant_compressor .compression_param_names
422
+ if param != "weight"
423
+ )
390
424
391
425
return list (unexpected_keys )
392
426
@@ -424,7 +458,21 @@ def compress_model(self, model: Module):
424
458
425
459
# quantization first
426
460
if prefix in module_to_scheme :
427
- state_dict = self .quantization_compressor .compress (
461
+ if (
462
+ not hasattr (module .quantization_scheme , "format" )
463
+ or module .quantization_scheme .format is None
464
+ ):
465
+ if len (self .compression_formats ) > 1 :
466
+ raise ValueError (
467
+ "Applying multiple compressors without defining "
468
+ "per module formats is not supported "
469
+ )
470
+ format = self .compression_formats [0 ]
471
+ else :
472
+ format = module .quantization_scheme .format
473
+
474
+ quant_compressor = self .quantization_compressor .get (format )
475
+ state_dict = quant_compressor .compress (
428
476
state_dict ,
429
477
names_to_scheme = module_to_scheme ,
430
478
show_progress = False ,
@@ -495,12 +543,24 @@ def decompress_model(self, model: Module):
495
543
496
544
# quantization second
497
545
if prefix in module_to_scheme :
498
- state_dict = (
499
- self .quantization_compressor .decompress_module_from_state_dict (
500
- prefix ,
501
- state_dict ,
502
- scheme = module_to_scheme [prefix ],
503
- )
546
+
547
+ if (
548
+ not hasattr (module .quantization_scheme , "format" )
549
+ or module .quantization_scheme .format is None
550
+ ):
551
+ if len (self .compression_formats ) > 1 :
552
+ raise ValueError (
553
+ "Applying multiple compressors without defining "
554
+ "per module formats is not supported "
555
+ )
556
+ format = self .compression_formats [0 ]
557
+ else :
558
+ format = module .quantization_scheme .format
559
+ quant_compressor = self .quantization_compressor .get (format )
560
+ state_dict = quant_compressor .decompress_module_from_state_dict (
561
+ prefix ,
562
+ state_dict ,
563
+ scheme = module_to_scheme [prefix ],
504
564
)
505
565
506
566
# remove any existing parameters
@@ -539,7 +599,9 @@ def compress(
539
599
540
600
if self .quantization_compressor is not None :
541
601
module_to_scheme = map_module_to_scheme (model )
542
- state_dict = self .quantization_compressor .compress (
602
+ # Note - compress only supports one compression format atm
603
+ quant_compressor = next (iter (self .quantization_compressor .values ()))
604
+ state_dict = quant_compressor .compress (
543
605
state_dict ,
544
606
names_to_scheme = module_to_scheme ,
545
607
show_progress = show_progress ,
@@ -588,14 +650,20 @@ def decompress(self, model_path: str, model: Module):
588
650
"""
589
651
model_path = get_safetensors_folder (model_path )
590
652
sparse_decompressed = False
653
+ quant_compressor = (
654
+ next (iter (self .quantization_compressor .values ()))
655
+ if self .quantization_compressor is not None
656
+ else None
657
+ )
591
658
592
659
if (
593
660
self .sparsity_compressor is not None
594
661
and self .sparsity_config .format != CompressionFormat .dense .value
595
662
):
663
+ # note - decompress only supports one compressor atm
596
664
params_to_ignore = None
597
- if self . quantization_compressor is not None :
598
- params_to_ignore = self . quantization_compressor .compression_param_names
665
+ if quant_compressor is not None :
666
+ params_to_ignore = quant_compressor .compression_param_names
599
667
# Sparse decompression is applied on the model_path
600
668
# The compressor will try and load any quantization parameters as well
601
669
# params_to_skip_load will skip over quantization params from being loaded
@@ -606,7 +674,7 @@ def decompress(self, model_path: str, model: Module):
606
674
setattr (model , SPARSITY_CONFIG_NAME , self .sparsity_compressor .config )
607
675
sparse_decompressed = True
608
676
609
- if self . quantization_compressor is not None :
677
+ if quant_compressor is not None :
610
678
# Temporarily set quantization status to FROZEN to prevent
611
679
# quantization during apply_quantization_config. This ensures
612
680
# that the dtypes of the weights are not unintentionally updated.
@@ -629,15 +697,15 @@ def decompress(self, model_path: str, model: Module):
629
697
# including initialization
630
698
load_weight_quantization = (
631
699
sparse_decompressed
632
- or isinstance (self . quantization_compressor , DenseCompressor )
700
+ or isinstance (quant_compressor , DenseCompressor )
633
701
),
634
702
)
635
703
636
704
model_path_or_state_dict = (
637
705
model .state_dict () if sparse_decompressed else model_path
638
706
)
639
707
640
- dense_gen = self . quantization_compressor .decompress (
708
+ dense_gen = quant_compressor .decompress (
641
709
model_path_or_state_dict , names_to_scheme = names_to_scheme
642
710
)
643
711
# TODO: all weight quantization params will be moved to the compressor
0 commit comments