@@ -126,13 +126,12 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
-    cached_non_lora_outputs = {}
+    cached_non_lora_output = None
 
     @pytest.fixture(scope="class", autouse=True)
     def get_base_pipeline_outputs(self):
         """
-        This fixture will be executed once per test class and will populate
-        the cached_non_lora_outputs dictionary.
+        This fixture is executed once per test class and caches the baseline output.
         """
         components, _, _ = self.get_dummy_components(self.scheduler_cls)
         pipe = self.pipeline_class(**components)
@@ -143,11 +142,11 @@ def get_base_pipeline_outputs(self):
         # explicitly.
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.cached_non_lora_outputs[self.scheduler_cls.__name__] = output_no_lora
+        type(self).cached_non_lora_output = output_no_lora  # write through the class so every test instance sees the cache
 
         # Ensures that there's no inconsistency when reusing the cache.
         yield
-        self.cached_non_lora_outputs.clear()
+        type(self).cached_non_lora_output = None
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
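A note on the caching mechanics above, since the pattern is easy to get wrong: a pytest fixture with scope="class" and autouse=True runs once before the first test of each test class, and the code after `yield` runs once after the last one. pytest instantiates the test class freshly for every test method, so the cached value has to be written through the class (hence `type(self)` above); a plain instance assignment would only reach the one instance the fixture ran with, leaving every other test reading the `None` class default. A minimal self-contained sketch of the pattern, with hypothetical names, not code from this commit:

    import pytest

    def expensive_baseline():
        # Hypothetical stand-in for the full pipeline invocation.
        return [0.1, 0.2, 0.3]

    class TestWithClassCache:
        cached_baseline = None  # class-level default shared by all instances

        @pytest.fixture(scope="class", autouse=True)
        def compute_baseline(self):
            # Runs once before the first test in this class.
            type(self).cached_baseline = expensive_baseline()
            yield
            # Runs once after the last test in this class.
            type(self).cached_baseline = None

        def test_first(self):
            # A fresh instance, but the read falls through to the class attribute.
            assert self.cached_baseline == [0.1, 0.2, 0.3]

        def test_second(self):
            # Another fresh instance; the cache is still visible.
            assert self.cached_baseline is not None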
@@ -261,6 +260,12 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
+    def get_cached_non_lora_output(self):
+        """Return the cached baseline output produced without any LoRA adapters."""
+        if self.cached_non_lora_output is None:
+            raise ValueError("The baseline output cache is empty. Ensure the fixture has been executed.")
+        return self.cached_non_lora_output
+
     # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
     def get_dummy_tokens(self):
         max_seq_length = 77
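The new accessor fails fast instead of silently handing back a missing baseline. Its guard behaves as in this self-contained sketch, where `CacheOwner` is a hypothetical stand-in that mirrors only the attribute and method added above:

    import pytest

    class CacheOwner:
        cached_non_lora_output = None  # mirrors the mixin's class default

        def get_cached_non_lora_output(self):
            if self.cached_non_lora_output is None:
                raise ValueError("The baseline output cache is empty. Ensure the fixture has been executed.")
            return self.cached_non_lora_output

    def test_guard_raises_when_cache_is_empty():
        # Without the class fixture having run, the guard surfaces the problem immediately.
        with pytest.raises(ValueError, match="cache is empty"):
            CacheOwner().get_cached_non_lora_output()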
@@ -339,7 +344,7 @@ def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
         """
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -353,7 +358,7 @@ def test_simple_inference_with_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -478,7 +483,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -514,7 +519,7 @@ def test_simple_inference_with_text_lora_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -544,7 +549,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -622,7 +627,7 @@ def test_simple_inference_with_partial_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -746,7 +751,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -787,7 +792,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
@@ -821,7 +826,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
@@ -895,7 +900,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1019,7 +1024,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1075,7 +1080,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1235,7 +1240,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1326,7 +1331,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1632,7 +1637,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+            output_no_lora = self.get_cached_non_lora_output()
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1807,7 +1812,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
@@ -1851,7 +1856,7 @@ def test_set_adapters_match_attention_kwargs(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         lora_scale = 0.5
@@ -2242,7 +2247,7 @@ def test_inference_load_delete_load_adapters(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
+        output_no_lora = self.get_cached_non_lora_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config)