@@ -244,15 +244,23 @@ def test_low_cpu_mem_usage_with_loading(self):
244244 "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." ,
245245 )
246246
247- def test_simple_inference_with_text_lora_and_scale (self ):
247+ def test_simple_inference_with_partial_text_lora (self ):
248248 """
249- Tests a simple inference with lora attached on the text encoder + scale argument
249+ Tests a simple inference with lora attached on the text encoder
250+ with different ranks and some adapters removed
250251 and makes sure it works as expected
251252 """
252- attention_kwargs_name = determine_attention_kwargs_name (self .pipeline_class )
253-
254253 for scheduler_cls in self .scheduler_classes :
255- components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
254+ components , _ , _ = self .get_dummy_components (scheduler_cls )
255+ # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
256+ text_lora_config = LoraConfig (
257+ r = 4 ,
258+ rank_pattern = {self .text_encoder_target_modules [i ]: i + 1 for i in range (3 )},
259+ lora_alpha = 4 ,
260+ target_modules = self .text_encoder_target_modules ,
261+ init_lora_weights = False ,
262+ use_dora = False ,
263+ )
256264 pipe = self .pipeline_class (** components )
257265 pipe = pipe .to (torch_device )
258266 pipe .set_progress_bar_config (disable = None )
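For context on the `rank_pattern` argument in the new config: a single PEFT `LoraConfig` can give each target module its own rank, overriding the default `r`. A minimal sketch outside the test harness (the toy module and names are illustrative, not part of this diff):

    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    class Toy(nn.Module):  # stand-in for a text-encoder block
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(16, 16)
            self.k_proj = nn.Linear(16, 16)
            self.v_proj = nn.Linear(16, 16)

    config = LoraConfig(
        r=4,  # default rank
        rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},  # per-module overrides
        target_modules=["q_proj", "k_proj", "v_proj"],
        lora_alpha=4,
        init_lora_weights=False,
    )
    model = get_peft_model(Toy(), config)
    for name, module in model.named_modules():
        if hasattr(module, "lora_A"):
            # lora_A's out_features equals the rank chosen for that module
            print(name, module.lora_A["default"].out_features)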
@@ -263,25 +271,39 @@ def test_simple_inference_with_text_lora_and_scale(self):

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

-            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
-            )
+            state_dict = {}
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
+                # supports missing layers (PR#8324).
+                state_dict = {
+                    f"text_encoder.{module_name}": param
+                    for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+                    if "text_model.encoder.layers.4" not in module_name
+                }

-            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
-            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    state_dict.update(
+                        {
+                            f"text_encoder_2.{module_name}": param
+                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+                            if "text_model.encoder.layers.4" not in module_name
+                        }
+                    )

+            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
-                "Lora + scale should change the output",
+                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
             )

-            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
-            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            # Unload lora and load it back using the pipe.load_lora_weights machinery
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(state_dict)

+            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
-                "Lora + 0 scale should lead to same result as no LoRA",
+                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
+                "Removing adapters should change the output",
             )

     @parameterized.expand([("fused",), ("unloaded",), ("save_load",)])
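The round trip above works because `load_lora_weights` also accepts an in-memory dict whose keys carry the pipeline-component prefix (`text_encoder.`, `text_encoder_2.`, and `unet.`/`transformer.` for the denoiser). A hedged sketch of the same partial-load idea against a checkpoint on disk (the model id and file path are placeholders, not from this PR):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # lora_state_dict returns the (converted) LoRA tensors and their network alphas.
    state_dict, network_alphas = pipe.lora_state_dict("path/to/lora.safetensors")

    # Drop everything touching text-encoder layer 4, then load the remainder;
    # modules without LoRA tensors simply keep their base weights.
    partial = {k: v for k, v in state_dict.items() if "text_model.encoder.layers.4" not in k}
    pipe.load_lora_weights(partial)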
@@ -369,66 +391,57 @@ def test_lora_text_encoder_actions(self, action):
369391 "Loading from a saved checkpoint should yield the same result as the original LoRA." ,
370392 )
371393
372- def test_simple_inference_with_partial_text_lora (self ):
394+ @parameterized .expand (
395+ [
396+ ("text_encoder_only" ,),
397+ ("text_and_denoiser" ,),
398+ ]
399+ )
400+ def test_lora_scaling (self , lora_components_to_add ):
373401 """
374- Tests a simple inference with lora attached on the text encoder
375- with different ranks and some adapters removed
376- and makes sure it works as expected
402+ Tests inference with LoRA scaling applied via attention_kwargs
403+ for different LoRA configurations.
377404 """
+        if lora_components_to_add == "text_encoder_only":
+            if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules):
+                pytest.skip(
+                    f"Test not supported for {self.__class__.__name__} since there is no text encoder in the LoRA loadable modules."
+                )
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
         for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
-            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
-            text_lora_config = LoraConfig(
-                r=4,
-                rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
-                lora_alpha=4,
-                target_modules=self.text_encoder_target_modules,
-                init_lora_weights=False,
-                use_dora=False,
+            pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = (
+                self._setup_pipeline_and_get_base_output(scheduler_cls)
             )
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
-
-            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
-
-            state_dict = {}
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
-                # supports missing layers (PR#8324).
-                state_dict = {
-                    f"text_encoder.{module_name}": param
-                    for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
-                    if "text_model.encoder.layers.4" not in module_name
-                }

-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    state_dict.update(
-                        {
-                            f"text_encoder_2.{module_name}": param
-                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
-                            if "text_model.encoder.layers.4" not in module_name
-                        }
-                    )
+            # Add LoRA components based on the parameterization
+            if lora_components_to_add == "text_encoder_only":
+                pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+            elif lora_components_to_add == "text_and_denoiser":
+                pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+            else:
+                raise ValueError(f"Unknown `lora_components_to_add`: {lora_components_to_add}")

+            # 1. Test base LoRA output
             output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+            self.assertFalse(
+                np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "LoRA should change the output."
             )

-            # Unload lora and load it back using the pipe.load_lora_weights machinery
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(state_dict)
+            # 2. Test with a scale of 0.5
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            self.assertFalse(
+                np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                "Using a LoRA scale should change the output.",
+            )

-            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            # 3. Test with a scale of 0.0, which should be identical to no LoRA
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
             self.assertTrue(
-                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
-                "Removing adapters should change the output",
+                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+                "Using a LoRA scale of 0.0 should be the same as no LoRA.",
             )

     def test_simple_inference_with_text_denoiser_lora_save_load(self):
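The `determine_attention_kwargs_name` helper exists because the keyword that carries the runtime LoRA scale differs by pipeline family: UNet-based pipelines use `cross_attention_kwargs`, while several transformer-based ones use `joint_attention_kwargs` or `attention_kwargs`. A minimal usage sketch for a UNet pipeline with a LoRA already loaded (the prompt is illustrative):

    # `scale` multiplies the LoRA delta at inference time; 0.0 reproduces the base model.
    image = pipe(
        "a castle in pixel-art style",
        num_inference_steps=20,
        cross_attention_kwargs={"scale": 0.5},
    ).images[0]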
@@ -469,52 +482,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
469482 "Loading from saved checkpoints should give same results." ,
470483 )
471484
472- def test_simple_inference_with_text_denoiser_lora_and_scale (self ):
473- """
474- Tests a simple inference with lora attached on the text encoder + Unet + scale argument
475- and makes sure it works as expected
476- """
477- attention_kwargs_name = determine_attention_kwargs_name (self .pipeline_class )
478-
479- for scheduler_cls in self .scheduler_classes :
480- components , text_lora_config , denoiser_lora_config = self .get_dummy_components (scheduler_cls )
481- pipe = self .pipeline_class (** components )
482- pipe = pipe .to (torch_device )
483- pipe .set_progress_bar_config (disable = None )
484- _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
485-
486- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
487- self .assertTrue (output_no_lora .shape == self .output_shape )
488-
489- pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
490-
491- output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
492- self .assertTrue (
493- not np .allclose (output_lora , output_no_lora , atol = 1e-3 , rtol = 1e-3 ), "Lora should change the output"
494- )
495-
496- attention_kwargs = {attention_kwargs_name : {"scale" : 0.5 }}
497- output_lora_scale = pipe (** inputs , generator = torch .manual_seed (0 ), ** attention_kwargs )[0 ]
498-
499- self .assertTrue (
500- not np .allclose (output_lora , output_lora_scale , atol = 1e-3 , rtol = 1e-3 ),
501- "Lora + scale should change the output" ,
502- )
503-
504- attention_kwargs = {attention_kwargs_name : {"scale" : 0.0 }}
505- output_lora_0_scale = pipe (** inputs , generator = torch .manual_seed (0 ), ** attention_kwargs )[0 ]
506-
507- self .assertTrue (
508- np .allclose (output_no_lora , output_lora_0_scale , atol = 1e-3 , rtol = 1e-3 ),
509- "Lora + 0 scale should lead to same result as no LoRA" ,
510- )
511-
512- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
513- self .assertTrue (
514- pipe .text_encoder .text_model .encoder .layers [0 ].self_attn .q_proj .scaling ["default" ] == 1.0 ,
515- "The scaling parameter has not been correctly restored!" ,
516- )
517-
518485 def test_simple_inference_with_text_lora_denoiser_fused (self ):
519486 """
520487 Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
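The deleted `scaling["default"] == 1.0` assertion guarded against a runtime scale leaking into adapter state: PEFT stores `scaling = lora_alpha / r` per adapter on each LoRA layer, and a per-call `scale` only multiplies it transiently. A sketch of the arithmetic, assuming `r = lora_alpha = 4` as in the config added earlier in this diff:

    lora_alpha, r = 4, 4
    base_scaling = lora_alpha / r             # 1.0 — what the assertion expected after the call
    runtime_scale = 0.5                       # passed via attention_kwargs
    effective = base_scaling * runtime_scale  # 0.5 — the LoRA delta is halved
    # layer output = frozen_linear(x) + effective * lora_B(lora_A(x))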