@@ -49,6 +49,8 @@ def benchmark_quantizer(
         quantizer_args: Dict
     ) -> Dict[str, float]:
         """Benchmark a specific quantizer with memory management."""
+        from transformers import AutoModelForCausalLM
+
         results = {}
         try:
             self._clear_memory()
@@ -63,18 +65,29 @@ def benchmark_quantizer(
             mem_efficient_args.update({
                 "percdamp": 0.01,
                 "block_size": 128,
-            })  # Create a deep copy of the model using from_pretrained
-            config = self.model.config
-            model_clone = type(self.model)(config)
-            # Copy weights manually to ensure proper copying
-            for param_name, param in self.model.state_dict().items():
-                if param_name in model_clone.state_dict():
-                    model_clone.state_dict()[param_name].copy_(param)
+            })
+
+            print(f"Creating copy of model for {name}...")
+            # Create a fresh model instance from the pretrained checkpoint
+            model_clone = AutoModelForCausalLM.from_pretrained(
+                self.model.config._name_or_path,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float32,
+                device_map=None  # Important: disable device map for copying
+            )
+
+            print(f"Copying parameters for {name}...")
+            # Manually copy parameters to ensure proper copying
+            with torch.no_grad():
+                for param_name, param in self.model.named_parameters():
+                    if param_name in model_clone.state_dict():
+                        # Ensure the parameter is on CPU before copying
+                        model_clone.state_dict()[param_name].copy_(param.cpu())

-            # Initialize quantizer with model copy on CPU
+            # Initialize quantizer with model copy
             quantizer = quantizer_class(model=model_clone, **mem_efficient_args)

-            # Move to device for quantization
+            # Move to appropriate device
             if self.device == "cuda":
                 quantizer.model = quantizer.model.cuda()
                 cal_data = self.calibration_data.cuda()
@@ -84,6 +97,7 @@ def benchmark_quantizer(
             # Measure quantization time
             start_time = time.time()

+            print(f"Starting quantization for {name}...")
             if name == "AWQ":
                 # AWQ uses batched processing
                 cal_steps = min(20, len(cal_data))
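The substance of this change is the model-cloning step: instead of constructing a bare model from `config` and copying the state dict into it, the benchmark reloads the checkpoint with `from_pretrained` and then syncs the live parameters on CPU, so each quantizer gets an independent copy to mutate. A minimal standalone sketch of that pattern is below; the helper name `clone_model_for_quantization` and the example model ID are illustrative assumptions, not part of this PR.

```python
import torch
from transformers import AutoModelForCausalLM


def clone_model_for_quantization(model):
    """Rebuild the model from its checkpoint, then sync weights from the
    in-memory original so a quantizer can modify the clone freely."""
    clone = AutoModelForCausalLM.from_pretrained(
        model.config._name_or_path,  # path/ID the original was loaded from
        low_cpu_mem_usage=True,      # avoid a second full allocation while loading
        torch_dtype=torch.float32,
        device_map=None,             # keep the clone on CPU for the copy
    )
    clone_state = clone.state_dict()
    with torch.no_grad():
        for param_name, param in model.named_parameters():
            if param_name in clone_state:
                clone_state[param_name].copy_(param.cpu())
    return clone


# Hypothetical usage:
# base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# quantizable_copy = clone_model_for_quantization(base)
```

Loading with `low_cpu_mem_usage=True` and `device_map=None` keeps the clone on CPU until the benchmark explicitly moves it to the target device, which should avoid holding two full copies on the GPU at once.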