@@ -43,6 +43,7 @@ def __init__(
4343 self .dtype = dtype
4444 self .offload_mode = None
4545 self .model_names = []
46+ self ._models_offload_params = {}
4647
4748 @classmethod
4849 def from_pretrained (
@@ -288,6 +289,10 @@ def _enable_model_cpu_offload(self):
288289 model = getattr (self , model_name )
289290 if model is not None :
290291 model .to ("cpu" )
292+ self ._models_offload_params [model_name ] = {}
293+ for name , param in model .named_parameters (recurse = True ):
294+ param .data = param .data .pin_memory ()
295+ self ._models_offload_params [model_name ][name ] = param .data
291296 self .offload_mode = "cpu_offload"
292297
293298 def _enable_sequential_cpu_offload (self ):
@@ -321,12 +326,14 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
321326 for model_name in self .model_names :
322327 if model_name not in load_model_names :
323328 model = getattr (self , model_name )
324- if model is not None and (p := next (model .parameters (), None )) is not None and p .device != "cpu" :
325- model .to ("cpu" )
329+ if model is not None and (p := next (model .parameters (), None )) is not None and p .device != torch .device ("cpu" ):
330+ param_cache = self ._models_offload_params [model_name ]
331+ for name , param in model .named_parameters (recurse = True ):
332+ param .data = param_cache [name ]
326333 # load the needed models to device
327334 for model_name in load_model_names :
328335 model = getattr (self , model_name )
329- if model is not None and (p := next (model .parameters (), None )) is not None and p .device != self .device :
336+ if model is not None and (p := next (model .parameters (), None )) is not None and p .device != torch . device ( self .device ) :
330337 model .to (self .device )
331338 # fresh the cuda cache
332339 empty_cache ()
0 commit comments