@@ -169,7 +169,7 @@ def from_pretrained_model(
169
169
cls ,
170
170
model : Module ,
171
171
sparsity_config : Union [SparsityCompressionConfig , str , None ] = None ,
172
- quantization_format : Optional [List [str ]] = None ,
172
+ quantization_format : Optional [Union [ str , List [str ] ]] = None ,
173
173
) -> Optional ["ModelCompressor" ]:
174
174
"""
175
175
Given a pytorch model and optional sparsity and/or quantization configs,
@@ -184,10 +184,14 @@ def from_pretrained_model(
184
184
"""
185
185
# assume multiple compression formats means mixed-precision
186
186
# as we currently only support one compressor per precision type and scheme
187
- if len (quantization_format ) > 1 :
188
- quantization_format = CompressionFormat .mixed_precision .value
189
- else :
190
- quantization_format = quantization_format [0 ]
187
+ if quantization_format is not None :
188
+ if isinstance (quantization_format , str ):
189
+ quantization_format = [quantization_format ]
190
+
191
+ if len (quantization_format ) > 1 :
192
+ quantization_format = CompressionFormat .mixed_precision .value
193
+ else :
194
+ quantization_format = quantization_format [0 ]
191
195
192
196
quantization_config = QuantizationConfig .from_pretrained (
193
197
model , format = quantization_format
@@ -408,12 +412,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
408
412
targets = scheme .targets ,
409
413
ignore = self .quantization_config .ignore ,
410
414
)
411
- unexpected_keys .update (
412
- merge_names (target , param )
413
- for target in quant_targets
414
- for param in self .quantization_compressor .compression_param_names
415
- if param != "weight"
416
- )
415
+ for quant_compressor in self .quantization_compressor .values ():
416
+ unexpected_keys .update (
417
+ merge_names (target , param )
418
+ for target in quant_targets
419
+ for param in quant_compressor .compression_param_names
420
+ if param != "weight"
421
+ )
417
422
418
423
return list (unexpected_keys )
419
424
@@ -451,9 +456,24 @@ def compress_model(self, model: Module):
451
456
452
457
# quantization first
453
458
if prefix in module_to_scheme :
454
- quant_compressor = self .quantization_compressor .get (
455
- module .quantization_scheme .format
456
- )
459
+ if (
460
+ not hasattr (module .quantization_scheme , "format" )
461
+ or module .quantization_scheme .format is None
462
+ ):
463
+ if (
464
+ self .quantization_config .format
465
+ == CompressionFormat .mixed_precision .value
466
+ ):
467
+ raise ValueError (
468
+ "Compressing mixed-precision models without defining "
469
+ "per module quantization_scheme.format is currently "
470
+ "not supported"
471
+ )
472
+ format = self .quantization_config .format
473
+ else :
474
+ format = module .quantization_scheme .format
475
+
476
+ quant_compressor = self .quantization_compressor .get (format )
457
477
state_dict = quant_compressor .compress (
458
478
state_dict ,
459
479
names_to_scheme = module_to_scheme ,
@@ -525,9 +545,24 @@ def decompress_model(self, model: Module):
525
545
526
546
# quantization second
527
547
if prefix in module_to_scheme :
528
- quant_compressor = self .quantization_compressor .get (
529
- module .quantization_scheme .format
530
- )
548
+
549
+ if (
550
+ not hasattr (module .quantization_scheme , "format" )
551
+ or module .quantization_scheme .format is None
552
+ ):
553
+ if (
554
+ self .quantization_config .format
555
+ == CompressionFormat .mixed_precision .value
556
+ ):
557
+ raise ValueError (
558
+ "Decompressing mixed-precision models without defining "
559
+ "per module quantization_scheme.format is currently not "
560
+ "supported"
561
+ )
562
+ format = self .quantization_config .format
563
+ else :
564
+ format = module .quantization_scheme .format
565
+ quant_compressor = self .quantization_compressor .get (format )
531
566
state_dict = quant_compressor .decompress_module_from_state_dict (
532
567
prefix ,
533
568
state_dict ,
@@ -621,15 +656,19 @@ def decompress(self, model_path: str, model: Module):
621
656
"""
622
657
model_path = get_safetensors_folder (model_path )
623
658
sparse_decompressed = False
659
+ quant_compressor = (
660
+ next (iter (self .quantization_compressor .values ()))
661
+ if self .quantization_compressor is not None
662
+ else None
663
+ )
624
664
625
665
if (
626
666
self .sparsity_compressor is not None
627
667
and self .sparsity_config .format != CompressionFormat .dense .value
628
668
):
629
669
# note - decompress only supports one compressor atm
630
- quant_compressor = next (iter (self .quantization_compressor .values ()))
631
670
params_to_ignore = None
632
- if self . quantization_compressor is not None :
671
+ if quant_compressor is not None :
633
672
params_to_ignore = quant_compressor .compression_param_names
634
673
# Sparse decompression is applied on the model_path
635
674
# The compressor will try and load any quantization parameters as well
0 commit comments