@@ -41,6 +41,7 @@ def __init__(
         self.offload_mode = None
         self.model_names = []
         self._offload_param_dict = {}
+        self.offload_to_disk = False
 
     @classmethod
     def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -228,19 +229,23 @@ def eval(self):
             model.eval()
         return self
 
-    def enable_cpu_offload(self, offload_mode: str):
-        valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
+    def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk: bool = False):
+        valid_offload_mode = ("cpu_offload", "sequential_cpu_offload", "disable", None)
         if offload_mode not in valid_offload_mode:
             raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
         if self.device == "cpu" or self.device == "mps":
             logger.warning("must set a non-CPU device for the pipeline before calling enable_cpu_offload")
             return
-        if offload_mode == "cpu_offload":
+        if offload_mode is None or offload_mode == "disable":
+            self._disable_offload()
+        elif offload_mode == "cpu_offload":
             self._enable_model_cpu_offload()
         elif offload_mode == "sequential_cpu_offload":
             self._enable_sequential_cpu_offload()
+        self.offload_to_disk = offload_to_disk
 
-    def _enable_model_cpu_offload(self):
+
+    def _enable_model_cpu_offload(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
             if model is not None:
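To make the new control surface concrete, here is a hedged usage sketch of the three modes as this diff defines them (pipeline construction and device handling are placeholders, not the project's confirmed API):

```python
pipe = BasePipeline.from_pretrained("path/to/model_or_config")
pipe.device = "cuda"  # offload needs a non-CPU device; how the device is set may differ

pipe.enable_cpu_offload("cpu_offload")  # whole models hop between CPU and GPU per use
pipe.enable_cpu_offload("sequential_cpu_offload", offload_to_disk=True)  # finer-grained, with the new disk-spill flag
pipe.enable_cpu_offload("disable")  # or None: undo offloading and move models back to the device
```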
@@ -253,13 +258,23 @@ def _enable_sequential_cpu_offload(self):
             if model is not None:
                 enable_sequential_cpu_offload(model, self.device)
         self.offload_mode = "sequential_cpu_offload"
+
+    def _disable_offload(self):
+        self.offload_mode = None
+        self._offload_param_dict = {}
+        for model_name in self.model_names:
+            model = getattr(self, model_name)
+            if model is not None:
+                model.to(self.device)
+
 
     def enable_fp8_autocast(
         self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
     ):
         for model_name in model_names:
             model = getattr(self, model_name)
             if model is not None:
+                model.to(device=self.device, dtype=torch.float8_e4m3fn)
                 enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
         self.fp8_autocast_enabled = True
@@ -282,10 +297,26 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
         # load the needed models to device
         for model_name in load_model_names:
             model = getattr(self, model_name)
+            if model is None:
+                raise ValueError(f"model {model_name} is not loaded; it may have been destroyed by model_lifecycle_finish with offload_to_disk=True")
             if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
                 model.to(self.device)
         # refresh the cuda cache
         empty_cache()
 
+    def model_lifecycle_finish(self, model_names: List[str] | None = None):
+        # With disk offload enabled, finished models can be reloaded from disk later, so drop them from memory now.
+        if not self.offload_to_disk or self.offload_mode is None:
+            return
+        for model_name in model_names or []:
+            model = getattr(self, model_name)
+            del model
+            if model_name in self._offload_param_dict:
+                del self._offload_param_dict[model_name]
+            setattr(self, model_name, None)
+            logger.info(f"model {model_name} has been deleted from memory")
+        empty_cache()
+
+
     def compile(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support compile")
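The intended call pattern, as I read this diff, is to release each model after its last use in a run; once a model is finished, `load_models_to_device` will raise for it. A hedged sketch with hypothetical stage and model names:

```python
pipe.enable_cpu_offload("sequential_cpu_offload", offload_to_disk=True)

pipe.load_models_to_device(["text_encoder"])
prompt_emb = ...  # encode the prompt (details elided)
pipe.model_lifecycle_finish(["text_encoder"])  # text encoder no longer needed this run

pipe.load_models_to_device(["dit"])  # would raise if "dit" had already been finished
latents = ...  # run denoising (details elided)
pipe.model_lifecycle_finish(["dit"])
```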