@@ -42,10 +42,7 @@
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.utils import (
-    is_module_quantized,
-    iter_named_leaf_modules,
-)
+from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -393,9 +390,16 @@ def compress_model(self, model: Module):
         )
 
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                module_device = get_execution_device(module).type
+                is_meta = (module_device == "meta")
+
+                exec_device = "meta" if is_meta else "cpu"
+                onloading_device = "meta" if is_meta else module_device
+
                 # in the future, support compression on same device
-                with align_module_device(module, execution_device="cpu"):
+                with align_module_device(module, execution_device=exec_device):
                     state_dict = module.state_dict(prefix=f"{prefix}.")
 
                 # quantization first
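Note on the new device logic: tensors on the "meta" device carry no data, so a meta-initialized module can neither be onloaded to CPU for compression nor have its compressed weights materialized on a real device afterwards; in that case both the execution device and the onloading device have to stay "meta". A minimal standalone sketch of the same selection rule in plain PyTorch (select_devices is a hypothetical name, not a compressed_tensors API):

import torch

def select_devices(module: torch.nn.Module) -> tuple:
    # Mirror of the branch above: meta parameters hold no data, so
    # compression must run shape-only on meta and stay there.
    try:
        device_type = next(module.parameters()).device.type
    except StopIteration:
        device_type = "cpu"  # assumption: treat parameter-less modules as CPU
    is_meta = device_type == "meta"
    exec_device = "meta" if is_meta else "cpu"             # where compression runs
    onloading_device = "meta" if is_meta else device_type  # where results land
    return exec_device, onloading_device

assert select_devices(torch.nn.Linear(4, 4, device="meta")) == ("meta", "meta")
assert select_devices(torch.nn.Linear(4, 4)) == ("cpu", "cpu")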
@@ -404,6 +408,7 @@ def compress_model(self, model: Module):
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
+                        compression_device=exec_device,
                     )
 
                 # sparsity second
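The new compression_device keyword presumably lets the quantization compressor allocate its packed outputs on the same device the weights were aligned to. A hedged illustration of why that matters (pack_int4 is a hypothetical toy packer, not the library's compressor): on "meta" only shapes are tracked, while on a real device actual values are produced.

import torch

def pack_int4(weight: torch.Tensor, compression_device: str) -> torch.Tensor:
    # Toy packer: two int4 values per int8 byte. On "meta" we allocate
    # a shape-only result; on a real device we actually compute it.
    packed_shape = (weight.shape[0], weight.shape[1] // 2)
    if compression_device == "meta":
        return torch.empty(packed_shape, dtype=torch.int8, device="meta")
    q = torch.clamp((weight * 8).round(), -8, 7).to(torch.int8)
    return (q[:, ::2] & 0x0F) | (q[:, 1::2] << 4)

print(pack_int4(torch.randn(8, 8), "cpu").shape)                   # torch.Size([8, 4])
print(pack_int4(torch.empty(8, 8, device="meta"), "meta").device)  # meta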
@@ -415,15 +420,14 @@ def compress_model(self, model: Module):
                     )
 
                 # remove any existing parameters
-                exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
                     delete_offload_parameter(module, name)
 
                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
+                    value = value.to(onloading_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
                     register_offload_parameter(module, name, param, offload_device)
 
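The swap itself: the module's old parameters are deleted, and the compressed tensors, moved to onloading_device, are re-registered with offload support. A simplified sketch of the same pattern without the compressed_tensors offloading helpers (replace_with_compressed is a hypothetical name):

import torch

def replace_with_compressed(module: torch.nn.Module, compressed: dict, device: str):
    # Drop every existing parameter (stand-in for delete_offload_parameter) ...
    for name, _ in list(module.named_parameters()):
        delattr(module, name)
    # ... then register the compressed tensors under their new names
    # (stand-in for register_offload_parameter).
    for name, value in compressed.items():
        param = torch.nn.Parameter(value.to(device), requires_grad=False)
        module.register_parameter(name, param)

linear = torch.nn.Linear(4, 4)
replace_with_compressed(linear, {"weight_packed": torch.zeros(4, 2, dtype=torch.int8)}, "cpu")
print([n for n, _ in linear.named_parameters()])  # ['weight_packed']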
@@ -747,7 +751,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
-        for name, module in iter_named_leaf_modules(model)
+        for name, module in model.named_modules()
         if is_module_quantized(module)
     }
 
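On the map_module_to_scheme change: the import of iter_named_leaf_modules is dropped in the first hunk, and a plain model.named_modules() walk suffices because the is_module_quantized filter discards containers that carry no quantization_scheme. A small self-contained demonstration, using an attribute check as a stand-in for is_module_quantized:

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

model = torch.nn.Sequential(Block())
model[0].proj.quantization_scheme = "W4A16"  # stand-in for a QuantizationScheme

# named_modules() also yields the containers ("" and "0"), but the
# filter keeps only the quantized leaf, matching the old leaf-only walk.
schemes = {
    name: m.quantization_scheme
    for name, m in model.named_modules()
    if getattr(m, "quantization_scheme", None) is not None
}
print(schemes)  # {'0.proj': 'W4A16'}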