@@ -158,6 +158,14 @@ def get_base_pipe_outs(self):
             cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})

         setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
+
+    def get_base_pipeline_output(self, scheduler_cls):
+        """
+        Returns the cached base pipeline output for the given scheduler.
+        Properly handles accessing the class-level cache.
+        """
+        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
+        return cached_base_pipe_outs[scheduler_cls.__name__]

     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
@@ -350,7 +358,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
@@ -365,7 +373,7 @@ def test_simple_inference_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -448,7 +456,7 @@ def test_low_cpu_mem_usage_with_loading(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -504,7 +512,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -542,7 +550,7 @@ def test_simple_inference_with_text_lora_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -574,7 +582,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -609,7 +617,7 @@ def test_simple_inference_with_text_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -660,7 +668,7 @@ def test_simple_inference_with_partial_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -711,7 +719,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -754,7 +762,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -795,7 +803,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -839,7 +847,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -877,7 +885,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -956,7 +964,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1085,7 +1093,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1142,7 +1150,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1305,7 +1313,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1399,7 +1407,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1643,7 +1651,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1724,7 +1732,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1779,7 +1787,7 @@ def test_simple_inference_with_dora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_dora_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_dora_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1911,7 +1919,7 @@ def test_logs_info_when_no_lora_keys_found(self):
        pipe.set_progress_bar_config(disable=None)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = self.cached_base_pipe_outs[scheduler_cls.__name__]
+        original_out = self.get_base_pipeline_output(scheduler_cls)

        no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
        logger = logging.get_logger("diffusers.loaders.peft")
@@ -1957,7 +1965,7 @@ def test_set_adapters_match_attention_kwargs(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2311,7 +2319,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+        output_no_lora = self.get_base_pipeline_output(scheduler_cls)
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, _ = self.add_adapters_to_pipeline(
@@ -2361,7 +2369,7 @@ def test_inference_load_delete_load_adapters(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)
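
For context on the pattern this commit encapsulates: get_base_pipe_outs stores the per-scheduler base outputs on the class via setattr(type(self), ...), so reads should also go through the class rather than hitting the raw dict on each test. Below is a minimal, standalone sketch of that class-level caching pattern under those assumptions; it is not part of the diffusers test suite, and the class name and placeholder scheduler classes are hypothetical.

class BasePipeOutputCacheDemo:
    """Hypothetical stand-in for the diffusers LoRA test mixin."""

    # Stand-ins for real scheduler classes; only __name__ is used as the cache key.
    scheduler_classes = [type("EulerLike", (), {}), type("DDIMLike", (), {})]

    def get_base_pipe_outs(self):
        cached_base_pipe_outs = {}
        for scheduler_cls in self.scheduler_classes:
            # Stand-in for running the pipeline once without any LoRA loaded.
            output_no_lora = f"base-output-for-{scheduler_cls.__name__}"
            cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
        # Store on the class so every test instance shares one computed cache.
        setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)

    def get_base_pipeline_output(self, scheduler_cls):
        # Read from the class, mirroring the setattr above. The {} default means
        # a missing cache surfaces as a KeyError at lookup time instead of an
        # AttributeError on the instance.
        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
        return cached_base_pipe_outs[scheduler_cls.__name__]


demo = BasePipeOutputCacheDemo()
demo.get_base_pipe_outs()
print(demo.get_base_pipeline_output(demo.scheduler_classes[0]))

The payoff of the helper is that every call site stops re-implementing the cache lookup, so if the storage strategy changes (e.g., to a setUpClass hook), only one method needs to move with it.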