@@ -229,7 +229,7 @@ def eval(self):
229229 model .eval ()
230230 return self
231231
232- def enable_cpu_offload (self , offload_mode : str | None , offload_to_disk :bool = False ):
232+ def enable_cpu_offload (self , offload_mode : str | None , offload_to_disk : bool = False ):
233233 valid_offload_mode = ("cpu_offload" , "sequential_cpu_offload" , "disable" , None )
234234 if offload_mode not in valid_offload_mode :
235235 raise ValueError (f"offload_mode must be one of { valid_offload_mode } , but got { offload_mode } " )
@@ -244,8 +244,7 @@ def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk:bool = Fa
244244 self ._enable_sequential_cpu_offload ()
245245 self .offload_to_disk = offload_to_disk
246246
247-
248- def _enable_model_cpu_offload (self ):
247+ def _enable_model_cpu_offload (self ):
249248 for model_name in self .model_names :
250249 model = getattr (self , model_name )
251250 if model is not None :
@@ -258,23 +257,22 @@ def _enable_sequential_cpu_offload(self):
258257 if model is not None :
259258 enable_sequential_cpu_offload (model , self .device )
260259 self .offload_mode = "sequential_cpu_offload"
261-
260+
262261 def _disable_offload (self ):
263- self .offload_mode = None
264- self ._offload_param_dict = {}
262+ self .offload_mode = None
263+ self ._offload_param_dict = {}
265264 for model_name in self .model_names :
266265 model = getattr (self , model_name )
267266 if model is not None :
268267 model .to (self .device )
269268
270-
271269 def enable_fp8_autocast (
272270 self , model_names : List [str ], compute_dtype : torch .dtype = torch .bfloat16 , use_fp8_linear : bool = False
273271 ):
274272 for model_name in model_names :
275273 model = getattr (self , model_name )
276274 if model is not None :
277- model .to (device = self . device , dtype = torch .float8_e4m3fn )
275+ model .to (dtype = torch .float8_e4m3fn )
278276 enable_fp8_autocast (model , compute_dtype , use_fp8_linear )
279277 self .fp8_autocast_enabled = True
280278
@@ -298,15 +296,17 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
298296 for model_name in load_model_names :
299297 model = getattr (self , model_name )
300298 if model is None :
301- raise ValueError (f"model { model_name } is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True" )
299+ raise ValueError (
300+ f"model { model_name } is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True"
301+ )
302302 if model is not None and (p := next (model .parameters (), None )) is not None and p .device .type != self .device :
303303 model .to (self .device )
304304 # fresh the cuda cache
305305 empty_cache ()
306306
307307 def model_lifecycle_finish (self , model_names : List [str ] | None = None ):
308308 if not self .offload_to_disk or self .offload_mode is None :
309- return
309+ return
310310 for model_name in model_names :
311311 model = getattr (self , model_name )
312312 del model
@@ -316,7 +316,6 @@ def model_lifecycle_finish(self, model_names: List[str] | None = None):
316316 print (f"model { model_name } has been deleted from memory" )
317317 logger .info (f"model { model_name } has been deleted from memory" )
318318 empty_cache ()
319-
320-
319+
321320 def compile (self ):
322321 raise NotImplementedError (f"{ self .__class__ .__name__ } does not support compile" )
0 commit comments