@@ -164,7 +164,7 @@ def from_pretrained_model(
164
164
cls ,
165
165
model : Module ,
166
166
sparsity_config : Union [SparsityCompressionConfig , str , None ] = None ,
167
- quantization_format : Optional [List [str ]] = None ,
167
+ quantization_format : Optional [Union [ str , List [str ] ]] = None ,
168
168
) -> Optional ["ModelCompressor" ]:
169
169
"""
170
170
Given a pytorch model and optional sparsity and/or quantization configs,
@@ -179,10 +179,14 @@ def from_pretrained_model(
179
179
"""
180
180
# assume multiple compression formats means mixed-precision
181
181
# as we currently only support one compressor per precision type and scheme
182
- if len (quantization_format ) > 1 :
183
- quantization_format = CompressionFormat .mixed_precision .value
184
- else :
185
- quantization_format = quantization_format [0 ]
182
+ if quantization_format is not None :
183
+ if isinstance (quantization_format , str ):
184
+ quantization_format = [quantization_format ]
185
+
186
+ if len (quantization_format ) > 1 :
187
+ quantization_format = CompressionFormat .mixed_precision .value
188
+ else :
189
+ quantization_format = quantization_format [0 ]
186
190
187
191
quantization_config = QuantizationConfig .from_pretrained (
188
192
model , format = quantization_format
@@ -393,12 +397,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
393
397
targets = scheme .targets ,
394
398
ignore = self .quantization_config .ignore ,
395
399
)
396
- unexpected_keys .update (
397
- merge_names (target , param )
398
- for target in quant_targets
399
- for param in self .quantization_compressor .compression_param_names
400
- if param != "weight"
401
- )
400
+ for quant_compressor in self .quantization_compressor .values ():
401
+ unexpected_keys .update (
402
+ merge_names (target , param )
403
+ for target in quant_targets
404
+ for param in quant_compressor .compression_param_names
405
+ if param != "weight"
406
+ )
402
407
403
408
return list (unexpected_keys )
404
409
@@ -436,9 +441,24 @@ def compress_model(self, model: Module):
436
441
437
442
# quantization first
438
443
if prefix in module_to_scheme :
439
- quant_compressor = self .quantization_compressor .get (
440
- module .quantization_scheme .format
441
- )
444
+ if (
445
+ not hasattr (module .quantization_scheme , "format" )
446
+ or module .quantization_scheme .format is None
447
+ ):
448
+ if (
449
+ self .quantization_config .format
450
+ == CompressionFormat .mixed_precision .value
451
+ ):
452
+ raise ValueError (
453
+ "Compressing mixed-precision models without defining "
454
+ "per module quantization_scheme.format is currently "
455
+ "not supported"
456
+ )
457
+ format = self .quantization_config .format
458
+ else :
459
+ format = module .quantization_scheme .format
460
+
461
+ quant_compressor = self .quantization_compressor .get (format )
442
462
state_dict = quant_compressor .compress (
443
463
state_dict ,
444
464
names_to_scheme = module_to_scheme ,
@@ -510,9 +530,24 @@ def decompress_model(self, model: Module):
510
530
511
531
# quantization second
512
532
if prefix in module_to_scheme :
513
- quant_compressor = self .quantization_compressor .get (
514
- module .quantization_scheme .format
515
- )
533
+
534
+ if (
535
+ not hasattr (module .quantization_scheme , "format" )
536
+ or module .quantization_scheme .format is None
537
+ ):
538
+ if (
539
+ self .quantization_config .format
540
+ == CompressionFormat .mixed_precision .value
541
+ ):
542
+ raise ValueError (
543
+ "Decompressing mixed-precision models without defining "
544
+ "per module quantization_scheme.format is currently not "
545
+ "supported"
546
+ )
547
+ format = self .quantization_config .format
548
+ else :
549
+ format = module .quantization_scheme .format
550
+ quant_compressor = self .quantization_compressor .get (format )
516
551
state_dict = quant_compressor .decompress_module_from_state_dict (
517
552
prefix ,
518
553
state_dict ,
@@ -606,15 +641,19 @@ def decompress(self, model_path: str, model: Module):
606
641
"""
607
642
model_path = get_safetensors_folder (model_path )
608
643
sparse_decompressed = False
644
+ quant_compressor = (
645
+ next (iter (self .quantization_compressor .values ()))
646
+ if self .quantization_compressor is not None
647
+ else None
648
+ )
609
649
610
650
if (
611
651
self .sparsity_compressor is not None
612
652
and self .sparsity_config .format != CompressionFormat .dense .value
613
653
):
614
654
# note - decompress only supports one compressor atm
615
- quant_compressor = next (iter (self .quantization_compressor .values ()))
616
655
params_to_ignore = None
617
- if self . quantization_compressor is not None :
656
+ if quant_compressor is not None :
618
657
params_to_ignore = quant_compressor .compression_param_names
619
658
# Sparse decompression is applied on the model_path
620
659
# The compressor will try and load any quantization parameters as well
0 commit comments