@@ -129,17 +129,21 @@ class PeftLoraLoaderMixinTests:
129129 text_encoder_target_modules = ["q_proj" , "k_proj" , "v_proj" , "out_proj" ]
130130 denoiser_target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ]
131131
132- cached_non_lora_outputs = {}
132+ cached_base_pipe_outs = {}
133133
134- @pytest .fixture (scope = "class" , autouse = True )
135- def cache_non_lora_outputs (self ):
136- """
137- This fixture will be executed once per test class and will populate
138- the cached_non_lora_outputs dictionary.
139- """
134+ def setUp (self ):
135+ self .get_base_pipe_outs ()
136+ super ().setUp ()
137+
138+ def get_base_pipe_outs (self ):
139+ cached_base_pipe_outs = getattr (type (self ), "cached_base_pipe_outs" , {})
140+ all_scheduler_names = [scheduler_cls .__name__ for scheduler_cls in self .scheduler_classes ]
141+ if cached_base_pipe_outs is not None and all (k in cached_base_pipe_outs for k in all_scheduler_names ):
142+ return
143+
144+ cached_base_pipe_outs = cached_base_pipe_outs or {}
140145 for scheduler_cls in self .scheduler_classes :
141- # Check if the output for this scheduler is already cached to avoid re-running
142- if scheduler_cls .__name__ in self .cached_non_lora_outputs :
146+ if scheduler_cls .__name__ in cached_base_pipe_outs :
143147 continue
144148
145149 components , _ , _ = self .get_dummy_components (scheduler_cls )
@@ -151,11 +155,9 @@ def cache_non_lora_outputs(self):
151155 # explicitly.
152156 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
153157 output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
154- self . cached_non_lora_outputs [ scheduler_cls .__name__ ] = output_no_lora
158+ cached_base_pipe_outs . update ({ scheduler_cls .__name__ : output_no_lora })
155159
156- # Ensures that there's no inconsistency when reusing the cache.
157- yield
158- self .cached_non_lora_outputs .clear ()
160+ setattr (type (self ), "cached_base_pipe_outs" , cached_base_pipe_outs )
159161
160162 def get_dummy_components (self , scheduler_cls = None , use_dora = False , lora_alpha = None ):
161163 if self .unet_kwargs and self .transformer_kwargs :
@@ -348,7 +350,7 @@ def test_simple_inference(self):
348350 Tests a simple inference and makes sure it works as expected
349351 """
350352 for scheduler_cls in self .scheduler_classes :
351- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
353+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
352354 self .assertTrue (output_no_lora .shape == self .output_shape )
353355
354356 def test_simple_inference_with_text_lora (self ):
@@ -363,7 +365,7 @@ def test_simple_inference_with_text_lora(self):
363365 pipe .set_progress_bar_config (disable = None )
364366 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
365367
366- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
368+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
367369 self .assertTrue (output_no_lora .shape == self .output_shape )
368370
369371 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -446,7 +448,7 @@ def test_low_cpu_mem_usage_with_loading(self):
446448 pipe .set_progress_bar_config (disable = None )
447449 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
448450
449- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
451+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
450452 self .assertTrue (output_no_lora .shape == self .output_shape )
451453
452454 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -502,7 +504,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
502504 pipe .set_progress_bar_config (disable = None )
503505 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
504506
505- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
507+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
506508 self .assertTrue (output_no_lora .shape == self .output_shape )
507509
508510 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -540,7 +542,7 @@ def test_simple_inference_with_text_lora_fused(self):
540542 pipe .set_progress_bar_config (disable = None )
541543 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
542544
543- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
545+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
544546 self .assertTrue (output_no_lora .shape == self .output_shape )
545547
546548 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -572,7 +574,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
572574 pipe .set_progress_bar_config (disable = None )
573575 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
574576
575- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
577+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
576578 self .assertTrue (output_no_lora .shape == self .output_shape )
577579
578580 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -607,7 +609,7 @@ def test_simple_inference_with_text_lora_save_load(self):
607609 pipe .set_progress_bar_config (disable = None )
608610 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
609611
610- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
612+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
611613 self .assertTrue (output_no_lora .shape == self .output_shape )
612614
613615 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -658,7 +660,7 @@ def test_simple_inference_with_partial_text_lora(self):
658660 pipe .set_progress_bar_config (disable = None )
659661 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
660662
661- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
663+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
662664 self .assertTrue (output_no_lora .shape == self .output_shape )
663665
664666 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -709,7 +711,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
709711 pipe .set_progress_bar_config (disable = None )
710712 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
711713
712- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
714+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
713715 self .assertTrue (output_no_lora .shape == self .output_shape )
714716
715717 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -752,7 +754,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
752754 pipe .set_progress_bar_config (disable = None )
753755 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
754756
755- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
757+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
756758 self .assertTrue (output_no_lora .shape == self .output_shape )
757759
758760 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -793,7 +795,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
793795 pipe .set_progress_bar_config (disable = None )
794796 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
795797
796- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
798+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
797799 self .assertTrue (output_no_lora .shape == self .output_shape )
798800
799801 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -837,7 +839,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
837839 pipe .set_progress_bar_config (disable = None )
838840 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
839841
840- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
842+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
841843 self .assertTrue (output_no_lora .shape == self .output_shape )
842844
843845 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -875,7 +877,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
875877 pipe .set_progress_bar_config (disable = None )
876878 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
877879
878- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
880+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
879881 self .assertTrue (output_no_lora .shape == self .output_shape )
880882
881883 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -954,7 +956,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
954956 pipe .set_progress_bar_config (disable = None )
955957 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
956958
957- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
959+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
958960
959961 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
960962 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1083,7 +1085,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10831085 pipe .set_progress_bar_config (disable = None )
10841086 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
10851087
1086- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1088+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
10871089
10881090 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
10891091 self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
@@ -1140,7 +1142,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
11401142 pipe .set_progress_bar_config (disable = None )
11411143 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
11421144
1143- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1145+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
11441146
11451147 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
11461148 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1303,7 +1305,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
13031305 pipe .set_progress_bar_config (disable = None )
13041306 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
13051307
1306- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1308+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
13071309
13081310 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
13091311 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1397,7 +1399,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13971399 pipe .set_progress_bar_config (disable = None )
13981400 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
13991401
1400- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1402+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
14011403
14021404 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
14031405 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1641,7 +1643,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
16411643 pipe .set_progress_bar_config (disable = None )
16421644 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
16431645
1644- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1646+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
16451647 self .assertTrue (output_no_lora .shape == self .output_shape )
16461648
16471649 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
@@ -1722,7 +1724,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
17221724 pipe .set_progress_bar_config (disable = None )
17231725 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
17241726
1725- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1727+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
17261728 self .assertTrue (output_no_lora .shape == self .output_shape )
17271729
17281730 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
@@ -1777,7 +1779,7 @@ def test_simple_inference_with_dora(self):
17771779 pipe .set_progress_bar_config (disable = None )
17781780 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
17791781
1780- output_no_dora_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1782+ output_no_dora_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
17811783 self .assertTrue (output_no_dora_lora .shape == self .output_shape )
17821784
17831785 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -1909,7 +1911,7 @@ def test_logs_info_when_no_lora_keys_found(self):
19091911 pipe .set_progress_bar_config (disable = None )
19101912
19111913 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
1912- original_out = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1914+ original_out = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
19131915
19141916 no_op_state_dict = {"lora_foo" : torch .tensor (2.0 ), "lora_bar" : torch .tensor (3.0 )}
19151917 logger = logging .get_logger ("diffusers.loaders.peft" )
@@ -1955,7 +1957,7 @@ def test_set_adapters_match_attention_kwargs(self):
19551957 pipe .set_progress_bar_config (disable = None )
19561958 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
19571959
1958- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
1960+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
19591961 self .assertTrue (output_no_lora .shape == self .output_shape )
19601962
19611963 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -2309,7 +2311,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
23092311 pipe = self .pipeline_class (** components ).to (torch_device )
23102312 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
23112313
2312- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
2314+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
23132315 self .assertTrue (output_no_lora .shape == self .output_shape )
23142316
23152317 pipe , _ = self .add_adapters_to_pipeline (
@@ -2359,7 +2361,7 @@ def test_inference_load_delete_load_adapters(self):
23592361 pipe .set_progress_bar_config (disable = None )
23602362 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
23612363
2362- output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
2364+ output_no_lora = self .cached_base_pipe_outs [scheduler_cls .__name__ ]
23632365
23642366 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
23652367 pipe .text_encoder .add_adapter (text_lora_config )
0 commit comments