@@ -129,17 +129,24 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
-    cached_non_lora_outputs = {}
+    cached_base_pipe_outputs = {}
+
+    def setUp(self):
+        super().setUp()
+        self.get_base_pipeline_output()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        cls.cached_base_pipe_outputs.clear()
+
+    def get_base_pipeline_output(self):
+        scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
+        if self.cached_base_pipe_outputs and all(k in self.cached_base_pipe_outputs for k in scheduler_names):
+            return
 
-    @pytest.fixture(scope="class", autouse=True)
-    def cache_non_lora_outputs(self):
-        """
-        This fixture will be executed once per test class and will populate
-        the cached_non_lora_outputs dictionary.
-        """
         for scheduler_cls in self.scheduler_classes:
-            # Check if the output for this scheduler is already cached to avoid re-running
-            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+            if scheduler_cls.__name__ in self.cached_base_pipe_outputs:
                 continue
 
             components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -151,11 +158,7 @@ def cache_non_lora_outputs(self):
             # explicitly.
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
-
-        # Ensures that there's no inconsistency when reusing the cache.
-        yield
-        self.cached_non_lora_outputs.clear()
+            self.cached_base_pipe_outputs[scheduler_cls.__name__] = output_no_lora
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        if self.unet_kwargs and self.transformer_kwargs:
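
The two hunks above replace the class-scoped pytest fixture with plain `unittest` hooks: a class-level dict is filled lazily the first time `setUp` runs and cleared in `tearDownClass`, so the expensive no-LoRA baseline inference happens once per test class instead of once per test. Below is a minimal, self-contained sketch of that caching pattern; the names (`ExpensiveOutputCachingTests`, `_cached_outputs`, `_expensive_baseline`) are hypothetical stand-ins for the diffusers-specific pieces, not part of the actual suite.

```python
import unittest


class ExpensiveOutputCachingTests(unittest.TestCase):
    # Shared by every test in this class (and any subclass that reuses it).
    _cached_outputs = {}

    def setUp(self):
        super().setUp()
        self._populate_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        # The dict lives on the class object, so clear it here to avoid
        # leaking cached state into other test classes sharing the mixin.
        cls._cached_outputs.clear()

    def _populate_cache(self):
        # Compute the expensive baseline only if it is not cached yet;
        # every subsequent test in the class reuses the stored result.
        if "baseline" not in self._cached_outputs:
            self._cached_outputs["baseline"] = self._expensive_baseline()

    def _expensive_baseline(self):
        return sum(range(1000))  # stand-in for a slow pipeline forward pass

    def test_uses_cached_baseline(self):
        self.assertEqual(self._cached_outputs["baseline"], 499500)
```

Clearing the cache in `tearDownClass` mirrors the teardown half of the removed fixture (the `yield` followed by `clear()`), which existed to keep stale outputs from being reused across test classes.
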
@@ -348,7 +351,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -363,7 +366,7 @@ def test_simple_inference_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -446,7 +449,7 @@ def test_low_cpu_mem_usage_with_loading(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -502,7 +505,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -540,7 +543,7 @@ def test_simple_inference_with_text_lora_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -572,7 +575,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -607,7 +610,7 @@ def test_simple_inference_with_text_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -658,7 +661,7 @@ def test_simple_inference_with_partial_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -709,7 +712,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -752,7 +755,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -793,7 +796,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -837,7 +840,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -875,7 +878,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -954,7 +957,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1083,7 +1086,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1140,7 +1143,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1303,7 +1306,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1397,7 +1400,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1641,7 +1644,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1722,7 +1725,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1777,7 +1780,7 @@ def test_simple_inference_with_dora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_dora_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1909,7 +1912,7 @@ def test_logs_info_when_no_lora_keys_found(self):
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            original_out = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1955,7 +1958,7 @@ def test_set_adapters_match_attention_kwargs(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2309,7 +2312,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(
@@ -2359,7 +2362,7 @@ def test_inference_load_delete_load_adapters(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)