@@ -129,6 +129,30 @@ class PeftLoraLoaderMixinTests:
129129 text_encoder_target_modules = ["q_proj" , "k_proj" , "v_proj" , "out_proj" ]
130130 denoiser_target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ]
131131
132+ cached_non_lora_outputs = {}
133+
134+ @pytest .fixture (scope = "class" , autouse = True )
135+ def cache_non_lora_outputs (self , request ):
136+ """
137+ This fixture will be executed once per test class and will populate
138+ the cached_non_lora_outputs dictionary.
139+ """
140+ for scheduler_cls in self .scheduler_classes :
141+ # Check if the output for this scheduler is already cached to avoid re-running
142+ if scheduler_cls .__name__ in self .cached_non_lora_outputs :
143+ continue
144+
145+ components , _ , _ = self .get_dummy_components (scheduler_cls )
146+ pipe = self .pipeline_class (** components )
147+ pipe = pipe .to (torch_device )
148+ pipe .set_progress_bar_config (disable = None )
149+
150+ # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
151+ # explicitly.
152+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
153+ output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
154+ self .cached_non_lora_outputs [scheduler_cls .__name__ ] = output_no_lora
155+
132156 def get_dummy_components (self , scheduler_cls = None , use_dora = False , lora_alpha = None ):
133157 if self .unet_kwargs and self .transformer_kwargs :
134158 raise ValueError ("Both `unet_kwargs` and `transformer_kwargs` cannot be specified." )
@@ -320,13 +344,7 @@ def test_simple_inference(self):
320344 Tests a simple inference and makes sure it works as expected
321345 """
322346 for scheduler_cls in self .scheduler_classes :
323- components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
324- pipe = self .pipeline_class (** components )
325- pipe = pipe .to (torch_device )
326- pipe .set_progress_bar_config (disable = None )
327-
328- _ , _ , inputs = self .get_dummy_inputs ()
329- output_no_lora = pipe (** inputs )[0 ]
347+ output_no_lora = self .cached_non_lora_outputs [scheduler_cls .__name__ ]
330348 self .assertTrue (output_no_lora .shape == self .output_shape )
331349
332350 def test_simple_inference_with_text_lora (self ):
@@ -341,7 +359,7 @@ def test_simple_inference_with_text_lora(self):
341359 pipe .set_progress_bar_config (disable = None )
342360 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
343361
344- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
362+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
345363 self .assertTrue (output_no_lora .shape == self .output_shape )
346364
347365 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -424,7 +442,7 @@ def test_low_cpu_mem_usage_with_loading(self):
424442 pipe .set_progress_bar_config (disable = None )
425443 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
426444
427- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
445+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
428446 self .assertTrue (output_no_lora .shape == self .output_shape )
429447
430448 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -480,7 +498,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
480498 pipe .set_progress_bar_config (disable = None )
481499 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
482500
483- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
501+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
484502 self .assertTrue (output_no_lora .shape == self .output_shape )
485503
486504 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -518,7 +536,7 @@ def test_simple_inference_with_text_lora_fused(self):
518536 pipe .set_progress_bar_config (disable = None )
519537 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
520538
521- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
539+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
522540 self .assertTrue (output_no_lora .shape == self .output_shape )
523541
524542 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -550,7 +568,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
550568 pipe .set_progress_bar_config (disable = None )
551569 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
552570
553- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
571+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
554572 self .assertTrue (output_no_lora .shape == self .output_shape )
555573
556574 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -585,7 +603,7 @@ def test_simple_inference_with_text_lora_save_load(self):
585603 pipe .set_progress_bar_config (disable = None )
586604 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
587605
588- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
606+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
589607 self .assertTrue (output_no_lora .shape == self .output_shape )
590608
591609 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -636,7 +654,7 @@ def test_simple_inference_with_partial_text_lora(self):
636654 pipe .set_progress_bar_config (disable = None )
637655 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
638656
639- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
657+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
640658 self .assertTrue (output_no_lora .shape == self .output_shape )
641659
642660 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -687,7 +705,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
687705 pipe .set_progress_bar_config (disable = None )
688706 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
689707
690- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
708+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
691709 self .assertTrue (output_no_lora .shape == self .output_shape )
692710
693711 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
@@ -730,7 +748,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
730748 pipe .set_progress_bar_config (disable = None )
731749 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
732750
733- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
751+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
734752 self .assertTrue (output_no_lora .shape == self .output_shape )
735753
736754 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -771,7 +789,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
771789 pipe .set_progress_bar_config (disable = None )
772790 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
773791
774- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
792+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
775793 self .assertTrue (output_no_lora .shape == self .output_shape )
776794
777795 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -815,7 +833,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
815833 pipe .set_progress_bar_config (disable = None )
816834 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
817835
818- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
836+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
819837 self .assertTrue (output_no_lora .shape == self .output_shape )
820838
821839 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -853,7 +871,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
853871 pipe .set_progress_bar_config (disable = None )
854872 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
855873
856- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
874+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
857875 self .assertTrue (output_no_lora .shape == self .output_shape )
858876
859877 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -932,7 +950,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
932950 pipe .set_progress_bar_config (disable = None )
933951 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
934952
935- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
953+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
936954
937955 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
938956 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1061,7 +1079,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10611079 pipe .set_progress_bar_config (disable = None )
10621080 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
10631081
1064- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1082+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
10651083
10661084 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
10671085 self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
@@ -1118,7 +1136,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
11181136 pipe .set_progress_bar_config (disable = None )
11191137 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
11201138
1121- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1139+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
11221140
11231141 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
11241142 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1281,7 +1299,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
12811299 pipe .set_progress_bar_config (disable = None )
12821300 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
12831301
1284- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1302+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
12851303
12861304 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
12871305 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1375,7 +1393,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13751393 pipe .set_progress_bar_config (disable = None )
13761394 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
13771395
1378- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1396+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
13791397
13801398 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
13811399 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1619,7 +1637,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
16191637 pipe .set_progress_bar_config (disable = None )
16201638 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
16211639
1622- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1640+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
16231641 self .assertTrue (output_no_lora .shape == self .output_shape )
16241642
16251643 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
@@ -1700,7 +1718,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
17001718 pipe .set_progress_bar_config (disable = None )
17011719 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
17021720
1703- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1721+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
17041722 self .assertTrue (output_no_lora .shape == self .output_shape )
17051723
17061724 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
@@ -1755,7 +1773,7 @@ def test_simple_inference_with_dora(self):
17551773 pipe .set_progress_bar_config (disable = None )
17561774 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
17571775
1758- output_no_dora_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1776+ output_no_dora_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
17591777 self .assertTrue (output_no_dora_lora .shape == self .output_shape )
17601778
17611779 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -1887,7 +1905,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18871905 pipe .set_progress_bar_config (disable = None )
18881906
18891907 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
1890- original_out = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1908+ original_out = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
18911909
18921910 no_op_state_dict = {"lora_foo" : torch .tensor (2.0 ), "lora_bar" : torch .tensor (3.0 )}
18931911 logger = logging .get_logger ("diffusers.loaders.peft" )
@@ -1933,7 +1951,7 @@ def test_set_adapters_match_attention_kwargs(self):
19331951 pipe .set_progress_bar_config (disable = None )
19341952 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
19351953
1936- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1954+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
19371955 self .assertTrue (output_no_lora .shape == self .output_shape )
19381956
19391957 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
@@ -2287,7 +2305,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22872305 pipe = self .pipeline_class (** components ).to (torch_device )
22882306 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
22892307
2290- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
2308+ output_no_lora = self . cached_non_lora_outputs ( scheduler_cls . __name__ )
22912309 self .assertTrue (output_no_lora .shape == self .output_shape )
22922310
22932311 pipe , _ = self .add_adapters_to_pipeline (
@@ -2337,7 +2355,7 @@ def test_inference_load_delete_load_adapters(self):
23372355 pipe .set_progress_bar_config (disable = None )
23382356 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
23392357
2340- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
2358+ output_no_lora = self . cached_non_lora_outputs [ scheduler_cls . __name__ ]
23412359
23422360 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
23432361 pipe .text_encoder .add_adapter (text_lora_config )
0 commit comments