@@ -182,7 +182,13 @@ def from_pretrained_model(
182
182
algorithm
183
183
:return: compressor for the configs, or None if model is not compressed
184
184
"""
185
- # reconstruct config from schemes attached to modules
185
+ # assume multiple compression formats means mixed-precision
186
+ # as we currently only support one compressor per precision type and scheme
187
+ if len (quantization_format ) > 1 :
188
+ quantization_format = CompressionFormat .mixed_precision
189
+ else :
190
+ quantization_format = quantization_format [0 ]
191
+
186
192
quantization_config = QuantizationConfig .from_pretrained (
187
193
model , format = quantization_format
188
194
)
@@ -263,6 +269,17 @@ def parse_quantization_config(
263
269
264
270
return quantization_config
265
271
272
+ def _fetch_unique_quantization_formats (self ):
273
+ """
274
+ Get all unique compression formats used in
275
+ model
276
+ """
277
+ quantization_formats = []
278
+ for _ , scheme in self .quantization_config .config_groups .items ():
279
+ if scheme .format not in quantization_formats :
280
+ quantization_formats .append (scheme )
281
+ return quantization_formats
282
+
266
283
def __init__ (
267
284
self ,
268
285
sparsity_config : Optional [SparsityCompressionConfig ] = None ,
@@ -275,26 +292,24 @@ def __init__(
275
292
276
293
self .sparsity_compressor = None
277
294
self .quantization_compressor : Optional [
278
- Union [BaseQuantizationCompressor , DenseCompressor ]
295
+ Dict [ str , Union [BaseQuantizationCompressor , DenseCompressor ] ]
279
296
] = None
280
297
# no transform compressor is required
281
298
282
299
if sparsity_config is not None :
283
300
self .sparsity_compressor = BaseCompressor .load_from_registry (
284
301
sparsity_config .format , config = sparsity_config
285
302
)
303
+
304
+ quantization_formats = self ._fetch_unique_quantization_formats ()
305
+
286
306
if quantization_config is not None :
287
- if isinstance (quantization_config .format , list ):
288
- self .quantization_compressor = {}
289
- for format in quantization_config .format :
290
- self .quantization_compressor [
291
- format
292
- ] = BaseCompressor .load_from_registry (
293
- format , config = quantization_config
294
- )
295
- else :
296
- self .quantization_compressor = BaseCompressor .load_from_registry (
297
- quantization_config .format , config = quantization_config
307
+ self .quantization_compressor = {}
308
+ for format in quantization_formats :
309
+ self .quantization_compressor [
310
+ format
311
+ ] = BaseCompressor .load_from_registry (
312
+ format , config = quantization_config
298
313
)
299
314
300
315
# ----- used by hf quantizer ----- #
@@ -433,23 +448,15 @@ def compress_model(self, model: Module):
433
448
434
449
# quantization first
435
450
if prefix in module_to_scheme :
436
- if isinstance (self .quantization_compressor , dict ):
437
- quant_compressor = self .quantization_compressor .get (
438
- module .quantization_scheme .format
439
- )
440
- state_dict = quant_compressor .compress (
441
- state_dict ,
442
- names_to_scheme = module_to_scheme ,
443
- show_progress = False ,
444
- compression_device = exec_device ,
445
- )
446
- else :
447
- state_dict = self .quantization_compressor .compress (
448
- state_dict ,
449
- names_to_scheme = module_to_scheme ,
450
- show_progress = False ,
451
- compression_device = exec_device ,
452
- )
451
+ quant_compressor = self .quantization_compressor .get (
452
+ module .quantization_scheme .format
453
+ )
454
+ state_dict = quant_compressor .compress (
455
+ state_dict ,
456
+ names_to_scheme = module_to_scheme ,
457
+ show_progress = False ,
458
+ compression_device = exec_device ,
459
+ )
453
460
454
461
# sparsity second
455
462
if prefix in sparse_compression_targets :
@@ -515,12 +522,13 @@ def decompress_model(self, model: Module):
515
522
516
523
# quantization second
517
524
if prefix in module_to_scheme :
518
- state_dict = (
519
- self .quantization_compressor .decompress_module_from_state_dict (
520
- prefix ,
521
- state_dict ,
522
- scheme = module_to_scheme [prefix ],
523
- )
525
+ quant_compressor = self .quantization_compressor .get (
526
+ module .quantization_scheme .format
527
+ )
528
+ state_dict = quant_compressor .decompress_module_from_state_dict (
529
+ prefix ,
530
+ state_dict ,
531
+ scheme = module_to_scheme [prefix ],
524
532
)
525
533
526
534
# remove any existing parameters
@@ -559,7 +567,9 @@ def compress(
559
567
560
568
if self .quantization_compressor is not None :
561
569
module_to_scheme = map_module_to_scheme (model )
562
- state_dict = self .quantization_compressor .compress (
570
+ # Note - compress only supports one compression format at the moment
571
+ quant_compressor = next (iter (self .quantization_compressor ))
572
+ state_dict = quant_compressor .compress (
563
573
state_dict ,
564
574
names_to_scheme = module_to_scheme ,
565
575
show_progress = show_progress ,
@@ -613,9 +623,11 @@ def decompress(self, model_path: str, model: Module):
613
623
self .sparsity_compressor is not None
614
624
and self .sparsity_config .format != CompressionFormat .dense .value
615
625
):
626
+ # note - decompress only supports one compressor so far
627
+ quant_compressor = next (iter (self .quantization_compressor ))
616
628
params_to_ignore = None
617
629
if self .quantization_compressor is not None :
618
- params_to_ignore = self . quantization_compressor .compression_param_names
630
+ params_to_ignore = quant_compressor .compression_param_names
619
631
# Sparse decompression is applied on the model_path
620
632
# The compressor will try and load any quantization parameters as well
621
633
# params_to_skip_load will skip over quantization params from being loaded
@@ -626,7 +638,7 @@ def decompress(self, model_path: str, model: Module):
626
638
setattr (model , SPARSITY_CONFIG_NAME , self .sparsity_compressor .config )
627
639
sparse_decompressed = True
628
640
629
- if self . quantization_compressor is not None :
641
+ if quant_compressor is not None :
630
642
# Temporarily set quantization status to FROZEN to prevent
631
643
# quantization during apply_quantization_config. This ensures
632
644
# that the dtypes of the weights are not unintentionally updated.
@@ -649,15 +661,15 @@ def decompress(self, model_path: str, model: Module):
649
661
# including initialization
650
662
load_weight_quantization = (
651
663
sparse_decompressed
652
- or isinstance (self . quantization_compressor , DenseCompressor )
664
+ or isinstance (quant_compressor , DenseCompressor )
653
665
),
654
666
)
655
667
656
668
model_path_or_state_dict = (
657
669
model .state_dict () if sparse_decompressed else model_path
658
670
)
659
671
660
- dense_gen = self . quantization_compressor .decompress (
672
+ dense_gen = quant_compressor .decompress (
661
673
model_path_or_state_dict , names_to_scheme = names_to_scheme
662
674
)
663
675
# TODO: all weight quantization params will be moved to the compressor
0 commit comments