@@ -177,6 +177,13 @@ def from_pretrained_model(
177
177
algorithm
178
178
:return: compressor for the configs, or None if model is not compressed
179
179
"""
180
+ # assume that multiple compression formats imply mixed-precision,
181
+ # as we currently only support one compressor per precision type and scheme
182
+ if len (quantization_format ) > 1 :
183
+ quantization_format = CompressionFormat .mixed_precision
184
+ else :
185
+ quantization_format = quantization_format [0 ]
186
+
180
187
quantization_config = QuantizationConfig .from_pretrained (
181
188
model , format = quantization_format
182
189
)
@@ -190,7 +197,8 @@ def from_pretrained_model(
190
197
return None
191
198
192
199
return cls (
193
- sparsity_config = sparsity_config , quantization_config = quantization_config
200
+ sparsity_config = sparsity_config ,
201
+ quantization_config = quantization_config ,
194
202
)
195
203
196
204
@staticmethod
@@ -250,6 +258,17 @@ def parse_quantization_config(
250
258
251
259
return quantization_config
252
260
261
+ def _fetch_unique_quantization_formats (self ):
262
+ """
263
+ Get all unique compression formats used in
264
+ model
265
+ """
266
+ quantization_formats = []
267
+ for _ , scheme in self .quantization_config .config_groups .items ():
268
+ if scheme .format not in quantization_formats :
269
+ quantization_formats .append (scheme )
270
+ return quantization_formats
271
+
253
272
def __init__ (
254
273
self ,
255
274
sparsity_config : Optional [SparsityCompressionConfig ] = None ,
@@ -259,25 +278,23 @@ def __init__(
259
278
self .quantization_config = quantization_config
260
279
self .sparsity_compressor = None
261
280
self .quantization_compressor : Optional [
262
- Union [BaseQuantizationCompressor , DenseCompressor ]
281
+ Dict [ str , Union [BaseQuantizationCompressor , DenseCompressor ] ]
263
282
] = None
264
283
265
284
if sparsity_config is not None :
266
285
self .sparsity_compressor = BaseCompressor .load_from_registry (
267
286
sparsity_config .format , config = sparsity_config
268
287
)
288
+
289
+ quantization_formats = self ._fetch_unique_quantization_formats ()
290
+
269
291
if quantization_config is not None :
270
- if isinstance (quantization_config .format , list ):
271
- self .quantization_compressor = {}
272
- for format in quantization_config .format :
273
- self .quantization_compressor [
274
- format
275
- ] = BaseCompressor .load_from_registry (
276
- format , config = quantization_config
277
- )
278
- else :
279
- self .quantization_compressor = BaseCompressor .load_from_registry (
280
- quantization_config .format , config = quantization_config
292
+ self .quantization_compressor = {}
293
+ for format in quantization_formats :
294
+ self .quantization_compressor [
295
+ format
296
+ ] = BaseCompressor .load_from_registry (
297
+ format , config = quantization_config
281
298
)
282
299
283
300
# ----- used by hf quantizer ----- #
@@ -416,23 +433,15 @@ def compress_model(self, model: Module):
416
433
417
434
# quantization first
418
435
if prefix in module_to_scheme :
419
- if isinstance (self .quantization_compressor , dict ):
420
- quant_compressor = self .quantization_compressor .get (
421
- module .quantization_scheme .format
422
- )
423
- state_dict = quant_compressor .compress (
424
- state_dict ,
425
- names_to_scheme = module_to_scheme ,
426
- show_progress = False ,
427
- compression_device = exec_device ,
428
- )
429
- else :
430
- state_dict = self .quantization_compressor .compress (
431
- state_dict ,
432
- names_to_scheme = module_to_scheme ,
433
- show_progress = False ,
434
- compression_device = exec_device ,
435
- )
436
+ quant_compressor = self .quantization_compressor .get (
437
+ module .quantization_scheme .format
438
+ )
439
+ state_dict = quant_compressor .compress (
440
+ state_dict ,
441
+ names_to_scheme = module_to_scheme ,
442
+ show_progress = False ,
443
+ compression_device = exec_device ,
444
+ )
436
445
437
446
# sparsity second
438
447
if prefix in sparse_compression_targets :
@@ -498,12 +507,13 @@ def decompress_model(self, model: Module):
498
507
499
508
# quantization second
500
509
if prefix in module_to_scheme :
501
- state_dict = (
502
- self .quantization_compressor .decompress_module_from_state_dict (
503
- prefix ,
504
- state_dict ,
505
- scheme = module_to_scheme [prefix ],
506
- )
510
+ quant_compressor = self .quantization_compressor .get (
511
+ module .quantization_scheme .format
512
+ )
513
+ state_dict = quant_compressor .decompress_module_from_state_dict (
514
+ prefix ,
515
+ state_dict ,
516
+ scheme = module_to_scheme [prefix ],
507
517
)
508
518
509
519
# remove any existing parameters
@@ -542,7 +552,9 @@ def compress(
542
552
543
553
if self .quantization_compressor is not None :
544
554
module_to_scheme = map_module_to_scheme (model )
545
- state_dict = self .quantization_compressor .compress (
555
+ # Note - compress only supports one compression format at the moment.
+ # NOTE(review): ``next(iter(self.quantization_compressor))`` yields a dict
+ # KEY (the format string), not a compressor instance — iterate
+ # ``.values()`` before calling ``.compress`` on the result.
556
+ quant_compressor = next (iter (self .quantization_compressor ))
557
+ state_dict = quant_compressor .compress (
546
558
state_dict ,
547
559
names_to_scheme = module_to_scheme ,
548
560
show_progress = show_progress ,
@@ -596,9 +608,11 @@ def decompress(self, model_path: str, model: Module):
596
608
self .sparsity_compressor is not None
597
609
and self .sparsity_config .format != CompressionFormat .dense .value
598
610
):
611
+ # note - decompress only supports one compressor so far.
+ # NOTE(review): ``quant_compressor`` is bound only inside this sparsity
+ # branch but is used unconditionally afterwards (potential NameError when
+ # the branch is skipped); also ``next(iter(dict))`` yields the format key,
+ # not the compressor — ``.values()`` is needed.
612
+ quant_compressor = next (iter (self .quantization_compressor ))
599
613
params_to_ignore = None
600
614
if self .quantization_compressor is not None :
601
- params_to_ignore = self . quantization_compressor .compression_param_names
615
+ params_to_ignore = quant_compressor .compression_param_names
602
616
# Sparse decompression is applied on the model_path
603
617
# The compressor will try and load any quantization parameters as well
604
618
# params_to_skip_load will skip over quantization params from being loaded
@@ -609,7 +623,7 @@ def decompress(self, model_path: str, model: Module):
609
623
setattr (model , SPARSITY_CONFIG_NAME , self .sparsity_compressor .config )
610
624
sparse_decompressed = True
611
625
612
- if self . quantization_compressor is not None :
626
+ if quant_compressor is not None :
613
627
# Temporarily set quantization status to FROZEN to prevent
614
628
# quantization during apply_quantization_config. This ensures
615
629
# that the dtypes of the weights are not unintentionally updated.
@@ -632,15 +646,15 @@ def decompress(self, model_path: str, model: Module):
632
646
# including initialization
633
647
load_weight_quantization = (
634
648
sparse_decompressed
635
- or isinstance (self . quantization_compressor , DenseCompressor )
649
+ or isinstance (quant_compressor , DenseCompressor )
636
650
),
637
651
)
638
652
639
653
model_path_or_state_dict = (
640
654
model .state_dict () if sparse_decompressed else model_path
641
655
)
642
656
643
- dense_gen = self . quantization_compressor .decompress (
657
+ dense_gen = quant_compressor .decompress (
644
658
model_path_or_state_dict , names_to_scheme = names_to_scheme
645
659
)
646
660
# TODO: all weight quantization params will be moved to the compressor
0 commit comments