@@ -251,61 +251,6 @@ def get_dummy_tokens(self):
251251 prepared_inputs ["input_ids" ] = inputs
252252 return prepared_inputs
253253
254- def _get_lora_state_dicts (self , modules_to_save ):
255- state_dicts = {}
256- for module_name , module in modules_to_save .items ():
257- if module is not None :
258- state_dicts [f"{ module_name } _lora_layers" ] = get_peft_model_state_dict (module )
259- return state_dicts
260-
261- def _get_lora_adapter_metadata (self , modules_to_save ):
262- metadatas = {}
263- for module_name , module in modules_to_save .items ():
264- if module is not None :
265- metadatas [f"{ module_name } _lora_adapter_metadata" ] = module .peft_config ["default" ].to_dict ()
266- return metadatas
267-
268- def _get_modules_to_save (self , pipe , has_denoiser = False ):
269- modules_to_save = {}
270- lora_loadable_modules = self .pipeline_class ._lora_loadable_modules
271-
272- if (
273- "text_encoder" in lora_loadable_modules
274- and hasattr (pipe , "text_encoder" )
275- and getattr (pipe .text_encoder , "peft_config" , None ) is not None
276- ):
277- modules_to_save ["text_encoder" ] = pipe .text_encoder
278-
279- if (
280- "text_encoder_2" in lora_loadable_modules
281- and hasattr (pipe , "text_encoder_2" )
282- and getattr (pipe .text_encoder_2 , "peft_config" , None ) is not None
283- ):
284- modules_to_save ["text_encoder_2" ] = pipe .text_encoder_2
285-
286- if has_denoiser :
287- if "unet" in lora_loadable_modules and hasattr (pipe , "unet" ):
288- modules_to_save ["unet" ] = pipe .unet
289-
290- if "transformer" in lora_loadable_modules and hasattr (pipe , "transformer" ):
291- modules_to_save ["transformer" ] = pipe .transformer
292-
293- return modules_to_save
294-
295- def _get_exclude_modules (self , pipe ):
296- from diffusers .utils .peft_utils import _derive_exclude_modules
297-
298- modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
299- denoiser = "unet" if self .unet_kwargs is not None else "transformer"
300- modules_to_save = {k : v for k , v in modules_to_save .items () if k == denoiser }
301- denoiser_lora_state_dict = self ._get_lora_state_dicts (modules_to_save )[f"{ denoiser } _lora_layers" ]
302- pipe .unload_lora_weights ()
303- denoiser_state_dict = pipe .unet .state_dict () if self .unet_kwargs is not None else pipe .transformer .state_dict ()
304- exclude_modules = _derive_exclude_modules (
305- denoiser_state_dict , denoiser_lora_state_dict , adapter_name = "default"
306- )
307- return exclude_modules
308-
309254 def add_adapters_to_pipeline (self , pipe , text_lora_config = None , denoiser_lora_config = None , adapter_name = "default" ):
310255 if text_lora_config is not None :
311256 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
@@ -2408,3 +2353,58 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
24082353 # materializes the test methods on invocation which cannot be overridden.
24092354 return
24102355 self ._test_group_offloading_inference_denoiser (offload_type , use_stream )
2356+
2357+ def _get_lora_state_dicts (self , modules_to_save ):
2358+ state_dicts = {}
2359+ for module_name , module in modules_to_save .items ():
2360+ if module is not None :
2361+ state_dicts [f"{ module_name } _lora_layers" ] = get_peft_model_state_dict (module )
2362+ return state_dicts
2363+
2364+ def _get_lora_adapter_metadata (self , modules_to_save ):
2365+ metadatas = {}
2366+ for module_name , module in modules_to_save .items ():
2367+ if module is not None :
2368+ metadatas [f"{ module_name } _lora_adapter_metadata" ] = module .peft_config ["default" ].to_dict ()
2369+ return metadatas
2370+
2371+ def _get_modules_to_save (self , pipe , has_denoiser = False ):
2372+ modules_to_save = {}
2373+ lora_loadable_modules = self .pipeline_class ._lora_loadable_modules
2374+
2375+ if (
2376+ "text_encoder" in lora_loadable_modules
2377+ and hasattr (pipe , "text_encoder" )
2378+ and getattr (pipe .text_encoder , "peft_config" , None ) is not None
2379+ ):
2380+ modules_to_save ["text_encoder" ] = pipe .text_encoder
2381+
2382+ if (
2383+ "text_encoder_2" in lora_loadable_modules
2384+ and hasattr (pipe , "text_encoder_2" )
2385+ and getattr (pipe .text_encoder_2 , "peft_config" , None ) is not None
2386+ ):
2387+ modules_to_save ["text_encoder_2" ] = pipe .text_encoder_2
2388+
2389+ if has_denoiser :
2390+ if "unet" in lora_loadable_modules and hasattr (pipe , "unet" ):
2391+ modules_to_save ["unet" ] = pipe .unet
2392+
2393+ if "transformer" in lora_loadable_modules and hasattr (pipe , "transformer" ):
2394+ modules_to_save ["transformer" ] = pipe .transformer
2395+
2396+ return modules_to_save
2397+
2398+ def _get_exclude_modules (self , pipe ):
2399+ from diffusers .utils .peft_utils import _derive_exclude_modules
2400+
2401+ modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
2402+ denoiser = "unet" if self .unet_kwargs is not None else "transformer"
2403+ modules_to_save = {k : v for k , v in modules_to_save .items () if k == denoiser }
2404+ denoiser_lora_state_dict = self ._get_lora_state_dicts (modules_to_save )[f"{ denoiser } _lora_layers" ]
2405+ pipe .unload_lora_weights ()
2406+ denoiser_state_dict = pipe .unet .state_dict () if self .unet_kwargs is not None else pipe .transformer .state_dict ()
2407+ exclude_modules = _derive_exclude_modules (
2408+ denoiser_state_dict , denoiser_lora_state_dict , adapter_name = "default"
2409+ )
2410+ return exclude_modules
0 commit comments