@@ -84,28 +84,28 @@ def _copy_model(self) -> PreTrainedModel:
8484 """Create a deep copy of the model, ensuring it's on CPU initially."""
8585 try :
8686 print ("Creating new model instance..." )
87- print ("Creating new model instance from config..." )
88- # Get model configuration
89- config = AutoConfig .from_pretrained (
90- self .model .config ._name_or_path , # Use the original model's name or path
91- trust_remote_code = True # Add trust_remote_code=True if needed for custom models
92- )
93-
94- # Create new model instance on CPU
95- new_model = AutoModelForCausalLM .from_config (config , trust_remote_code = True ).to ("cpu" )
96-
97- print ("Copying model parameters (state_dict) to CPU..." )
98- # Copy state dict from the original self.model (which is on CPU)
99- with torch .no_grad ():
100- state_dict_cpu = {k : v .cpu () for k , v in self .model .state_dict ().items ()}
101- new_model .load_state_dict (state_dict_cpu , assign = True , strict = True )
102- del state_dict_cpu # Free memory
87+ print ("Creating new model instance from config..." )
88+ # Get model configuration
89+ config = AutoConfig .from_pretrained (
90+ self .model .config ._name_or_path , # Use the original model's name or path
91+ trust_remote_code = True # Add trust_remote_code=True if needed for custom models
92+ )
10393
104- return new_model
105-
106- except Exception as e :
107- print (f"Detailed error in _copy_model: { type (e ).__name__ } : { e } " )
108- raise RuntimeError (f"Failed to copy model: { str (e )} " )
94+ # Create new model instance on CPU
95+ new_model = AutoModelForCausalLM .from_config (config , trust_remote_code = True ).to ("cpu" )
96+
97+ print ("Copying model parameters (state_dict) to CPU..." )
98+ # Copy state dict from the original self.model (which is on CPU)
99+ with torch .no_grad ():
100+ state_dict_cpu = {k : v .cpu () for k , v in self .model .state_dict ().items ()}
101+ new_model .load_state_dict (state_dict_cpu , assign = True , strict = True )
102+ del state_dict_cpu # Free memory
103+
104+ return new_model
105+
106+ except Exception as e :
107+ print (f"Detailed error in _copy_model: { type (e ).__name__ } : { e } " )
108+ raise RuntimeError (f"Failed to copy model: { str (e )} " )
109109
110110 def benchmark_quantizer (
111111 self ,
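One thing worth noting about the copy pattern above: Tensor.cpu() returns the original tensor unchanged when it already lives on the CPU, and load_state_dict(..., assign=True) keeps the tensors it is handed rather than copying them into the module's existing parameters, so an explicit clone is what makes the copy fully independent of self.model. A minimal standalone sketch of the same config-plus-state_dict approach with that adjustment (the tiny checkpoint name is only a placeholder, and assign=True needs a PyTorch release recent enough to have the keyword, roughly 2.1+):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

def copy_model_on_cpu(model):
    """Rebuild the architecture from its config, then load cloned CPU weights."""
    config = AutoConfig.from_pretrained(model.config._name_or_path, trust_remote_code=True)
    new_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to("cpu")
    with torch.no_grad():
        # detach().clone() guarantees fresh storage even when the source is already on CPU.
        state_dict_cpu = {k: v.detach().clone().cpu() for k, v in model.state_dict().items()}
        new_model.load_state_dict(state_dict_cpu, assign=True, strict=True)
    return new_model

if __name__ == "__main__":
    original = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder checkpoint
    clone = copy_model_on_cpu(original)
    # The copies must not share storage: compare data pointers of the first parameter.
    assert next(clone.parameters()).data_ptr() != next(original.parameters()).data_ptr()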
@@ -541,11 +541,11 @@ def plot_comparison(self, save_path: str = None):
 
         plt.close(fig)  # Close the figure to free memory
         # No explicit cleanup for self.model or self.calibration_data here, they are persistent.
-        if self.pynvml_available and self.nvml_handle:
-            # nvmlShutdown is typically called once when the application exits, not per benchmark.
-            # For now, do not shut down NVML here to allow multiple calls to benchmark_quantizer or run_all_benchmarks.
-            # Consider adding a __del__ or close() method to QuantizationBenchmark for global NVML shutdown.
-            pass
+        if self.pynvml_available and self.nvml_handle:
+            # nvmlShutdown is typically called once when the application exits, not per benchmark.
+            # For now, do not shut down NVML here to allow multiple calls to benchmark_quantizer or run_all_benchmarks.
+            # Consider adding a __del__ or close() method to QuantizationBenchmark for global NVML shutdown.
+            pass
 
     def __del__(self):
         # Destructor to ensure NVML is shut down when the object is deleted or program exits.
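The comment in this hunk suggests adding a __del__ or close() method for a global NVML shutdown, and the trailing context shows a __del__ already defined, though the hunk ends before its body. For completeness, a sketch of what such a close()/__del__ pair could look like, assuming the pynvml package; only the class name and the pynvml_available / nvml_handle attributes come from the diff, the rest is illustrative:

import pynvml

class QuantizationBenchmark:
    # ... rest of the class as shown in the diff ...

    def close(self):
        # Shut NVML down exactly once, when the benchmark object is finished.
        if getattr(self, "pynvml_available", False):
            try:
                pynvml.nvmlShutdown()
            except pynvml.NVMLError:
                pass  # Already shut down, or NVML was never initialized.
            self.pynvml_available = False
            self.nvml_handle = None

    def __del__(self):
        # Destructor to ensure NVML is shut down when the object is deleted or program exits.
        # Errors are swallowed because __del__ can run during interpreter teardown.
        try:
            self.close()
        except Exception:
            pass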