@@ -115,6 +115,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
115115 `default_{i}` where i is the total number of adapters being loaded.
116116 weight_name (`str`, *optional*, defaults to None):
117117 Name of the serialized state dict file.
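118+ low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
119+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
120+ weights. Passing `True` requires a `peft` version newer than 0.13.0; otherwise a `ValueError` is raised.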
118121
119122 Example:
120123
@@ -142,8 +145,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
142145 adapter_name = kwargs .pop ("adapter_name" , None )
143146 _pipeline = kwargs .pop ("_pipeline" , None )
144147 network_alphas = kwargs .pop ("network_alphas" , None )
148+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , False )
145149 allow_pickle = False
146150
151+ if low_cpu_mem_usage and is_peft_version ("<=" , "0.13.0" ):
152+ raise ValueError (
153+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
154+ )
155+
147156 if use_safetensors is None :
148157 use_safetensors = True
149158 allow_pickle = True
@@ -209,6 +218,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
209218 network_alphas = network_alphas ,
210219 adapter_name = adapter_name ,
211220 _pipeline = _pipeline ,
221+ low_cpu_mem_usage = low_cpu_mem_usage ,
212222 )
213223 else :
214224 raise ValueError (
@@ -268,7 +278,9 @@ def _process_custom_diffusion(self, state_dict):
268278
269279 return attn_processors
270280
271- def _process_lora (self , state_dict , unet_identifier_key , network_alphas , adapter_name , _pipeline ):
281+ def _process_lora (
282+ self , state_dict , unet_identifier_key , network_alphas , adapter_name , _pipeline , low_cpu_mem_usage
283+ ):
272284 # This method does the following things:
273285 # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
274286 # format. For legacy format no filtering is applied.
@@ -335,9 +347,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
335347 # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
336348 # otherwise loading LoRA weights will lead to an error
337349 is_model_cpu_offload , is_sequential_cpu_offload = self ._optionally_disable_offloading (_pipeline )
350+ peft_kwargs = {}
351+ if is_peft_version (">=" , "0.13.1" ):
352+ peft_kwargs ["low_cpu_mem_usage" ] = low_cpu_mem_usage
338353
339- inject_adapter_in_model (lora_config , self , adapter_name = adapter_name )
340- incompatible_keys = set_peft_model_state_dict (self , state_dict , adapter_name )
354+ inject_adapter_in_model (lora_config , self , adapter_name = adapter_name , ** peft_kwargs )
355+ incompatible_keys = set_peft_model_state_dict (self , state_dict , adapter_name , ** peft_kwargs )
341356
342357 if incompatible_keys is not None :
343358 # check only for unexpected keys
0 commit comments