@@ -133,16 +133,7 @@ def get_base_pipeline_outputs(self):
133133 """
134134 This fixture is executed once per test class and caches the baseline outputs.
135135 """
136- components , _ , _ = self .get_dummy_components (self .scheduler_cls )
137- pipe = self .pipeline_class (** components )
138- pipe = pipe .to (torch_device )
139- pipe .set_progress_bar_config (disable = None )
140-
141- # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
142- # explicitly.
143- _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
144- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
145- self .cached_non_lora_output = output_no_lora
136+ self .cached_non_lora_output = self ._compute_baseline_output ()
146137
147138 # Ensures that there's no inconsistency when reusing the cache.
148139 yield
@@ -260,12 +251,23 @@ def get_dummy_inputs(self, with_generator=True):
260251
261252 return noise , input_ids , pipeline_inputs
262253
263- def get_cached_non_lora_output (self ):
254+ def get_base_pipe_output (self ):
264255 """Return the cached baseline output produced without any LoRA adapters."""
265256 if self .cached_non_lora_output is None :
266- raise ValueError ( "The baseline output cache is empty. Ensure the fixture has been executed." )
257+ self . cached_non_lora_output = self . _compute_baseline_output ( )
267258 return self .cached_non_lora_output
268259
260+ def _compute_baseline_output (self ):
261+ components , _ , _ = self .get_dummy_components (self .scheduler_cls )
262+ pipe = self .pipeline_class (** components )
263+ pipe = pipe .to (torch_device )
264+ pipe .set_progress_bar_config (disable = None )
265+
266+ # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
267+ # explicitly.
268+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
269+ return pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
270+
269271 # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
270272 def get_dummy_tokens (self ):
271273 max_seq_length = 77
@@ -344,7 +346,7 @@ def test_simple_inference(self):
344346 """
345347 Tests a simple inference and makes sure it works as expected
346348 """
347- output_no_lora = self .get_cached_non_lora_output ()
349+ output_no_lora = self .get_base_pipe_output ()
348350 self .assertTrue (output_no_lora .shape == self .output_shape )
349351
350352 def test_simple_inference_with_text_lora (self ):
@@ -358,7 +360,7 @@ def test_simple_inference_with_text_lora(self):
358360 pipe .set_progress_bar_config (disable = None )
359361 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
360362
361- output_no_lora = self .get_cached_non_lora_output ()
363+ output_no_lora = self .get_base_pipe_output ()
362364 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
363365
364366 output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -483,7 +485,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
483485 pipe .set_progress_bar_config (disable = None )
484486 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
485487
486- output_no_lora = self .get_cached_non_lora_output ()
488+ output_no_lora = self .get_base_pipe_output ()
487489
488490 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
489491
@@ -519,7 +521,7 @@ def test_simple_inference_with_text_lora_fused(self):
519521 pipe .set_progress_bar_config (disable = None )
520522 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
521523
522- output_no_lora = self .get_cached_non_lora_output ()
524+ output_no_lora = self .get_base_pipe_output ()
523525
524526 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
525527
@@ -549,7 +551,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
549551 pipe .set_progress_bar_config (disable = None )
550552 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
551553
552- output_no_lora = self .get_cached_non_lora_output ()
554+ output_no_lora = self .get_base_pipe_output ()
553555
554556 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
555557
@@ -627,7 +629,7 @@ def test_simple_inference_with_partial_text_lora(self):
627629 pipe .set_progress_bar_config (disable = None )
628630 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
629631
630- output_no_lora = self .get_cached_non_lora_output ()
632+ output_no_lora = self .get_base_pipe_output ()
631633
632634 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
633635
@@ -751,7 +753,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
751753 pipe .set_progress_bar_config (disable = None )
752754 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
753755
754- output_no_lora = self .get_cached_non_lora_output ()
756+ output_no_lora = self .get_base_pipe_output ()
755757 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
756758
757759 output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -792,7 +794,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
792794 pipe .set_progress_bar_config (disable = None )
793795 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
794796
795- output_no_lora = self .get_cached_non_lora_output ()
797+ output_no_lora = self .get_base_pipe_output ()
796798
797799 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
798800
@@ -826,7 +828,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
826828 pipe .set_progress_bar_config (disable = None )
827829 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
828830
829- output_no_lora = self .get_cached_non_lora_output ()
831+ output_no_lora = self .get_base_pipe_output ()
830832
831833 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
832834
@@ -900,7 +902,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
900902 pipe .set_progress_bar_config (disable = None )
901903 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
902904
903- output_no_lora = self .get_cached_non_lora_output ()
905+ output_no_lora = self .get_base_pipe_output ()
904906
905907 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
906908 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1024,7 +1026,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10241026 pipe .set_progress_bar_config (disable = None )
10251027 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
10261028
1027- output_no_lora = self .get_cached_non_lora_output ()
1029+ output_no_lora = self .get_base_pipe_output ()
10281030
10291031 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
10301032 self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
@@ -1080,7 +1082,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
10801082 pipe .set_progress_bar_config (disable = None )
10811083 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
10821084
1083- output_no_lora = self .get_cached_non_lora_output ()
1085+ output_no_lora = self .get_base_pipe_output ()
10841086
10851087 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
10861088 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1240,7 +1242,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
12401242 pipe .set_progress_bar_config (disable = None )
12411243 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
12421244
1243- output_no_lora = self .get_cached_non_lora_output ()
1245+ output_no_lora = self .get_base_pipe_output ()
12441246
12451247 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
12461248 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1331,7 +1333,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13311333 pipe .set_progress_bar_config (disable = None )
13321334 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
13331335
1334- output_no_lora = self .get_cached_non_lora_output ()
1336+ output_no_lora = self .get_base_pipe_output ()
13351337
13361338 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
13371339 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1637,7 +1639,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
16371639 pipe .set_progress_bar_config (disable = None )
16381640 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
16391641
1640- output_no_lora = self .get_cached_non_lora_output ()
1642+ output_no_lora = self .get_base_pipe_output ()
16411643
16421644 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
16431645 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1812,7 +1814,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18121814 pipe .set_progress_bar_config (disable = None )
18131815
18141816 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
1815- output_no_lora = self .get_cached_non_lora_output ()
1817+ output_no_lora = self .get_base_pipe_output ()
18161818
18171819 no_op_state_dict = {"lora_foo" : torch .tensor (2.0 ), "lora_bar" : torch .tensor (3.0 )}
18181820 logger = logging .get_logger ("diffusers.loaders.peft" )
@@ -1856,7 +1858,7 @@ def test_set_adapters_match_attention_kwargs(self):
18561858 pipe .set_progress_bar_config (disable = None )
18571859 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
18581860
1859- output_no_lora = self .get_cached_non_lora_output ()
1861+ output_no_lora = self .get_base_pipe_output ()
18601862 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
18611863
18621864 lora_scale = 0.5
@@ -2247,7 +2249,7 @@ def test_inference_load_delete_load_adapters(self):
22472249 pipe .set_progress_bar_config (disable = None )
22482250 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
22492251
2250- output_no_lora = self .get_cached_non_lora_output ()
2252+ output_no_lora = self .get_base_pipe_output ()
22512253
22522254 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
22532255 pipe .text_encoder .add_adapter (text_lora_config )
0 commit comments