@@ -129,24 +129,17 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
-    cached_base_pipe_outputs = {}
-
-    def setUp(self):
-        super().setUp()
-        self.get_base_pipeline_output()
-
-    @classmethod
-    def tearDownClass(cls):
-        super().tearDownClass()
-        cls.cached_base_pipe_outputs.clear()
-
-    def get_base_pipeline_output(self):
-        scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
-        if self.cached_base_pipe_outputs and all(k in self.cached_base_pipe_outputs for k in scheduler_names):
-            return
+    cached_non_lora_outputs = {}
 
+    @pytest.fixture(scope="class", autouse=True)
+    def cache_non_lora_outputs(self):
+        """
+        This fixture will be executed once per test class and will populate
+        the cached_non_lora_outputs dictionary.
+        """
         for scheduler_cls in self.scheduler_classes:
-            if scheduler_cls.__name__ in self.cached_base_pipe_outputs:
+            # Check if the output for this scheduler is already cached to avoid re-running
+            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
                 continue
 
             components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -158,7 +151,11 @@ def get_base_pipeline_output(self):
             # explicitly.
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.cached_base_pipe_outputs[scheduler_cls.__name__] = output_no_lora
+            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+
+        # Ensures that there's no inconsistency when reusing the cache.
+        yield
+        self.cached_non_lora_outputs.clear()
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
@@ -351,7 +348,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -366,7 +363,7 @@ def test_simple_inference_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -449,7 +446,7 @@ def test_low_cpu_mem_usage_with_loading(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -505,7 +502,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -543,7 +540,7 @@ def test_simple_inference_with_text_lora_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -575,7 +572,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -610,7 +607,7 @@ def test_simple_inference_with_text_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -661,7 +658,7 @@ def test_simple_inference_with_partial_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -712,7 +709,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -755,7 +752,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -796,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -840,7 +837,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -878,7 +875,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -957,7 +954,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1086,7 +1083,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1143,7 +1140,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1306,7 +1303,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1400,7 +1397,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1644,7 +1641,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1725,7 +1722,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1780,7 +1777,7 @@ def test_simple_inference_with_dora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1912,7 +1909,7 @@ def test_logs_info_when_no_lora_keys_found(self):
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            original_out = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1958,7 +1955,7 @@ def test_set_adapters_match_attention_kwargs(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2312,7 +2309,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(
@@ -2362,7 +2359,7 @@ def test_inference_load_delete_load_adapters(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)
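
Note on the pattern above: the diff replaces unittest-style setUp()/tearDownClass() caching with a single class-scoped, autouse pytest fixture, so the expensive no-LoRA pipeline outputs are computed once per test class, shared through a class-level dict, and cleared after the last test in the class. A minimal, self-contained sketch of the same pattern follows; it assumes plain pytest classes, and the names (TestWithSharedCache, cache_outputs, the fake "expensive" computation) are hypothetical stand-ins, not diffusers APIs.

import pytest


class TestWithSharedCache:
    # Class-level dict shared by every test in this class. The fixture only
    # mutates it in place (item assignment / clear), so it does not matter
    # that pytest may bind the fixture and the tests to different instances.
    cached_outputs = {}

    @pytest.fixture(scope="class", autouse=True)
    def cache_outputs(self):
        # Runs once, before the first test of the class.
        for key in ("euler", "ddim"):
            if key in self.cached_outputs:
                continue  # already cached; skip the expensive recomputation
            # Stand-in for an expensive call such as running a pipeline.
            self.cached_outputs[key] = f"output-for-{key}"
        yield
        # Teardown runs once, after the last test of the class, so the
        # cache never leaks into other classes that reuse the mixin.
        self.cached_outputs.clear()

    def test_reads_cache(self):
        assert self.cached_outputs["euler"] == "output-for-euler"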