     get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
-    merge_names,
     patch_attr,
     register_offload_parameter,
     update_parameter_data,
@@ -226,7 +225,8 @@ def parse_sparsity_config(
             s_config = compression_config.sparsity_config
             return s_config.model_dump() if s_config is not None else None
 
-        return compression_config.get(SPARSITY_CONFIG_NAME, None)
+        # explicitly return None if {} in config
+        return compression_config.get(SPARSITY_CONFIG_NAME, None) or None
 
     @staticmethod
     def parse_quantization_config(
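Editor's note on the hunk above (illustration only, not part of the PR): the added `or None` matters because an empty dict stored under the sparsity key is falsy, so it now collapses to `None` instead of being returned as `{}`. A minimal standalone sketch of that behavior, assuming `SPARSITY_CONFIG_NAME` resolves to the string used below:

```python
# Standalone sketch of the truthiness the new `or None` relies on.
SPARSITY_CONFIG_NAME = "sparsity_config"  # assumed value, for illustration only


def parse(compression_config: dict):
    # An empty dict ({}) is falsy, so `{} or None` yields None;
    # a populated config dict passes through unchanged.
    return compression_config.get(SPARSITY_CONFIG_NAME, None) or None


assert parse({}) is None
assert parse({"sparsity_config": {}}) is None  # explicitly None when the config is {}
assert parse({"sparsity_config": {"format": "dense"}}) == {"format": "dense"}
```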
@@ -316,117 +316,11 @@ def __init__(
 
             self.quantization_compressor = {}
             for format in self.compression_formats:
-                self.quantization_compressor[
-                    format
-                ] = BaseCompressor.load_from_registry(
-                    format, config=quantization_config
-                )
-
-    # ----- used by hf quantizer ----- #
-
-    def get_missing_module_keys(self, model: Module) -> List[str]:
-        """
-        Identifies the expected missing weight keys in the compressed state_dict.
-
-        When a model undergoes sparsity or quantization compression, certain
-        weight tensors may be absent from the checkpoint by virtue of compression.
-        This function determines which weight keys are missing based on the
-        applied compression techniques.
-
-        :param model: The PyTorch model to check for missing keys.
-        :return: A list of missing keys expected in the compressed state_dict.
-        """
-        missing_keys = set()
-
-        # Determine missing keys due to sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-
-            missing_keys.update(
-                merge_names(target_name, "weight")
-                for target_name, _module in sparse_targets
-            )
-
-        # Determine missing keys due to pack quantization
-        if (
-            self.quantization_compressor
-            and self.quantization_config.format
-            == CompressionFormat.pack_quantized.value
-        ):
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                missing_keys.update(
-                    merge_names(target_name, "weight")
-                    for target_name, _module in quant_targets
-                )
-
-        return list(missing_keys)
-
-    def get_unexpected_file_keys(self, model: Module) -> List[str]:
-        """
-        Identifies extra keys introduced by the compression process in the
-        compressed state_dict that are not expected by the model graph.
-
-        During sparsity or quantization compression, additional metadata or
-        auxiliary parameters may be stored in the checkpoint, which do not
-        correspond to any parameter in the original model. These keys are
-        typically introduced to support the reconstruction of compressed weights.
-
-        For example, Sparse24Bitmask compression may introduce keys such as
-        'compressed', 'bitmask', and 'shape' in the checkpoint, which are
-        not part of the original model parameters.
-
-        :param model: The PyTorch model to check for unexpected keys.
-        :return: A list of extra keys introduced by the compression process
-            that are not expected by the model.
-        """
-
-        unexpected_keys = set()
-
-        # Identify unexpected keys from sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-            unexpected_keys.update(
-                merge_names(target_name, param)
-                for target_name, _module in sparse_targets
-                for param in self.sparsity_compressor.compression_param_names
-            )
-
-        # Identify unexpected keys from quantization compression
-        if self.quantization_compressor:
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                for quant_compressor in self.quantization_compressor.values():
-                    unexpected_keys.update(
-                        merge_names(target_name, param)
-                        for target_name, _module in quant_targets
-                        for param in quant_compressor.compression_param_names
-                        if param != "weight"
+                self.quantization_compressor[format] = (
+                    BaseCompressor.load_from_registry(
+                        format, config=quantization_config
                     )
-
-        return list(unexpected_keys)
+                )
 
     # ----- model memory compression/decompression pathways ----- #
 
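For readers unfamiliar with the `load_from_registry` call reshaped above, here is a rough sketch of the registry pattern it relies on; `_Registry` and `PackQuantizedCompressor` are illustrative stand-ins, not the actual `BaseCompressor` API:

```python
# Hypothetical registry sketch: each format string maps to a compressor class
# that is instantiated on demand, one instance per configured format.
class _Registry:
    _classes = {}

    @classmethod
    def register(cls, name):
        def wrap(klass):
            cls._classes[name] = klass
            return klass
        return wrap

    @classmethod
    def load_from_registry(cls, name, config=None):
        return cls._classes[name](config=config)


@_Registry.register("pack-quantized")
class PackQuantizedCompressor:
    def __init__(self, config=None):
        self.config = config


# Mirrors the loop in __init__: one compressor instance per compression format.
compression_formats = ["pack-quantized"]
quantization_compressor = {
    fmt: _Registry.load_from_registry(fmt, config=None) for fmt in compression_formats
}
print(quantization_compressor)
```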
@@ -716,17 +610,16 @@ def decompress(self, model_path: str, model: Module):
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a
             # dense compressor or if a sparsity compressor has already been applied
+            load_weight_qparams = sparse_decompressed or isinstance(
+                quant_compressor, DenseCompressor
+            )
             load_pretrained_quantization_parameters(
                 model,
                 model_path,
                 # TODO: all weight quantization params will be moved to the
                 # compressor in a follow-up including initialization
-                load_weight_quantization=(
-                    sparse_decompressed
-                    or isinstance(quant_compressor, DenseCompressor)
-                ),
+                load_weight_qparams=load_weight_qparams,
             )
-
             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
             )
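The hunk above hoists the boolean into `load_weight_qparams` so one value steers both paths: weight scales/zero-points are either read here from the checkpoint or registered later by `_replace_weights`, never both. A toy sketch of that invariant (the function and its arguments are hypothetical, not library code):

```python
# Toy sketch: the flag makes the two qparam-loading paths mutually exclusive.
def qparam_paths(sparse_decompressed: bool, is_dense_compressor: bool):
    load_weight_qparams = sparse_decompressed or is_dense_compressor

    loaded_from_checkpoint = load_weight_qparams      # load_pretrained_quantization_parameters(...)
    loaded_during_replace = not load_weight_qparams   # _replace_weights(..., load_weight_qparams=...)

    assert loaded_from_checkpoint != loaded_during_replace  # exactly one path runs
    return loaded_from_checkpoint, loaded_during_replace


print(qparam_paths(sparse_decompressed=True, is_dense_compressor=False))   # (True, False)
print(qparam_paths(sparse_decompressed=False, is_dense_compressor=False))  # (False, True)
```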
@@ -736,7 +629,9 @@ def decompress(self, model_path: str, model: Module):
             )
             # TODO: all weight quantization params will be moved to the compressor
             # to prevent duplicate parameter updates in update_parameter_data
-            self._replace_weights(dense_gen, model)
+            self._replace_weights(
+                dense_gen, model, load_weight_qparams=not load_weight_qparams
+            )
 
             def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
@@ -823,7 +718,9 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
             param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
             register_offload_parameter(module, param_name, param)
 
-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_weights(
+        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
+    ):
         """
         Replace the weights of the model with the
         provided dense weights.
@@ -851,6 +748,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                 # decompression in init to be consistent with loading which happens
                 # later as well however, update_data does a good shape check -
                 # should be moved to the compressor
+
                 if param_name == "weight":
                     delattr(module, param_name)
                     requires_grad = param_data.dtype in (
@@ -862,7 +760,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                         param_data.to(device), requires_grad=requires_grad
                     )
                     register_offload_parameter(module, param_name, param)
-                else:
+                elif load_weight_qparams:
                     # Should already be registered to the correct device for
                     # for scales/zero-points
                     update_parameter_data(module, param_data, param_name)
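To make the `else:` to `elif load_weight_qparams:` change above concrete, here is a toy stand-in for the per-parameter loop in `_replace_weights` (plain dicts instead of modules and offload helpers, purely illustrative):

```python
# Toy sketch of the branching: the dense weight is always swapped in,
# scales/zero-points only when load_weight_qparams is True.
def replace_params(module_params: dict, new_data: dict, load_weight_qparams: bool = True) -> dict:
    for param_name, param_data in new_data.items():
        if param_name == "weight":
            module_params[param_name] = param_data
        elif load_weight_qparams:
            module_params[param_name] = param_data
    return module_params


params = {"weight": None, "weight_scale": None}
replace_params(params, {"weight": [1.0], "weight_scale": [0.5]}, load_weight_qparams=False)
assert params["weight"] == [1.0] and params["weight_scale"] is None  # qparams left untouched
```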