@@ -392,15 +392,18 @@ def compress_model(self, model: Module):
392
392
for prefix , module in tqdm (model .named_modules (), desc = "Compressing model" ):
393
393
394
394
if prefix in module_to_scheme or prefix in sparse_compression_targets :
395
- module_device = get_execution_device (module ). type
396
- is_meta = ( module_device == "meta" )
395
+ module_device = get_execution_device (module )
396
+ is_meta = module_device . type == "meta"
397
397
398
398
exec_device = "meta" if is_meta else "cpu"
399
399
onloading_device = "meta" if is_meta else module_device
400
400
401
401
# in the future, support compression on same device
402
402
with align_module_device (module , execution_device = exec_device ):
403
- state_dict = module .state_dict (prefix = f"{ prefix } ." )
403
+ state_dict = {
404
+ f"{ prefix } .{ name } " : param
405
+ for name , param in module .named_parameters (recurse = False )
406
+ }
404
407
405
408
# quantization first
406
409
if prefix in module_to_scheme :
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):
421
424
422
425
# remove any existing parameters
423
426
offload_device = get_offloaded_device (module )
424
- for name , _ in list (module .named_parameters ()):
427
+ for name , _ in list (module .named_parameters (recurse = False )):
425
428
delete_offload_parameter (module , name )
426
429
427
430
# replace with compressed parameters
@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
458
461
if prefix in module_to_scheme or prefix in sparse_compression_targets :
459
462
# in the future, support decompression on same device
460
463
with align_module_device (module , execution_device = "cpu" ):
461
- state_dict = module .state_dict (prefix = f"{ prefix } ." )
464
+ state_dict = {
465
+ f"{ prefix } .{ name } " : param
466
+ for name , param in module .named_parameters (recurse = False )
467
+ }
462
468
463
469
# sparsity first
464
470
if prefix in sparse_compression_targets :
@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
483
489
# remove any existing parameters
484
490
exec_device = get_execution_device (module )
485
491
offload_device = get_offloaded_device (module )
486
- for name , _ in list (module .named_parameters ()):
492
+ for name , _ in list (module .named_parameters (recurse = False )):
487
493
delete_offload_parameter (module , name )
488
494
489
495
# replace with decompressed parameters
@@ -747,12 +753,16 @@ def _replace_weights(self, dense_weight_generator, model: Module):
747
753
748
754
def map_module_to_scheme (model : Module ) -> Dict [str , QuantizationScheme ]:
749
755
"""
750
- Returns a dictionary which maps quantized module names to their quantization schemes
756
+ Returns a dictionary which maps quantized module names to their quantization
757
+ schemes. Only includes modules with weight quantization
751
758
"""
752
759
return {
753
760
fix_fsdp_module_name (name ): module .quantization_scheme
754
761
for name , module in model .named_modules ()
755
- if is_module_quantized (module )
762
+ if (
763
+ hasattr (module , "quantization_scheme" )
764
+ and module .quantization_scheme .weights is not None
765
+ )
756
766
}
757
767
758
768
0 commit comments