File tree Expand file tree Collapse file tree 1 file changed +8
-5
lines changed
Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -63,10 +63,13 @@ def benchmark_quantizer(
6363 mem_efficient_args .update ({
6464 "percdamp" : 0.01 ,
6565 "block_size" : 128 ,
66- })
67- # Create a deep copy of the model using state dict
68- model_clone = type (self .model )(self .model .config )
69- model_clone .load_state_dict (self .model .state_dict ())
66+ }) # Create an independent copy of the model: rebuild it from its config, then copy the weights over
67+ config = self .model .config
68+ model_clone = type (self .model )(config )
69+ # Copy each state-dict entry in place into the clone; names the clone does not have are skipped
70+ for param_name , param in self .model .state_dict ().items ():
71+ if param_name in model_clone .state_dict ():
72+ model_clone .state_dict ()[param_name ].copy_ (param )
7073
7174 # Initialize quantizer with model copy on CPU
7275 quantizer = quantizer_class (model = model_clone , ** mem_efficient_args )
@@ -76,7 +79,7 @@ def benchmark_quantizer(
7679 quantizer .model = quantizer .model .cuda ()
7780 cal_data = self .calibration_data .cuda ()
7881 else :
79- cal_data = self .calibration_data
82+ cal_data = self .calibration_data . clone ()
8083
8184 # Measure quantization time
8285 start_time = time .time ()
You can’t perform that action at this time.
0 commit comments