Skip to content

Commit 2c47a2f

Browse files
committed
Revert "up"
This reverts commit 772c32e.
1 parent 772c32e commit 2c47a2f

File tree

1 file changed

+39
-42
lines changed

1 file changed

+39
-42
lines changed

tests/lora/utils.py

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -129,24 +129,17 @@ class PeftLoraLoaderMixinTests:
129129
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
130130
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
131131

132-
cached_base_pipe_outputs = {}
133-
134-
def setUp(self):
135-
super().setUp()
136-
self.get_base_pipeline_output()
137-
138-
@classmethod
139-
def tearDownClass(cls):
140-
super().tearDownClass()
141-
cls.cached_base_pipe_outputs.clear()
142-
143-
def get_base_pipeline_output(self):
144-
scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
145-
if self.cached_base_pipe_outputs and all(k in self.cached_base_pipe_outputs for k in scheduler_names):
146-
return
132+
cached_non_lora_outputs = {}
147133

134+
@pytest.fixture(scope="class", autouse=True)
135+
def cache_non_lora_outputs(self):
136+
"""
137+
This fixture will be executed once per test class and will populate
138+
the cached_non_lora_outputs dictionary.
139+
"""
148140
for scheduler_cls in self.scheduler_classes:
149-
if scheduler_cls.__name__ in self.cached_base_pipe_outputs:
141+
# Check if the output for this scheduler is already cached to avoid re-running
142+
if scheduler_cls.__name__ in self.cached_non_lora_outputs:
150143
continue
151144

152145
components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -158,7 +151,11 @@ def get_base_pipeline_output(self):
158151
# explicitly.
159152
_, _, inputs = self.get_dummy_inputs(with_generator=False)
160153
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
161-
self.cached_base_pipe_outputs[scheduler_cls.__name__] = output_no_lora
154+
self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
155+
156+
# Ensures that there's no inconsistency when reusing the cache.
157+
yield
158+
self.cached_non_lora_outputs.clear()
162159

163160
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
164161
if self.unet_kwargs and self.transformer_kwargs:
@@ -351,7 +348,7 @@ def test_simple_inference(self):
351348
Tests a simple inference and makes sure it works as expected
352349
"""
353350
for scheduler_cls in self.scheduler_classes:
354-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
351+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
355352
self.assertTrue(output_no_lora.shape == self.output_shape)
356353

357354
def test_simple_inference_with_text_lora(self):
@@ -366,7 +363,7 @@ def test_simple_inference_with_text_lora(self):
366363
pipe.set_progress_bar_config(disable=None)
367364
_, _, inputs = self.get_dummy_inputs(with_generator=False)
368365

369-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
366+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
370367
self.assertTrue(output_no_lora.shape == self.output_shape)
371368

372369
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -449,7 +446,7 @@ def test_low_cpu_mem_usage_with_loading(self):
449446
pipe.set_progress_bar_config(disable=None)
450447
_, _, inputs = self.get_dummy_inputs(with_generator=False)
451448

452-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
449+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
453450
self.assertTrue(output_no_lora.shape == self.output_shape)
454451

455452
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -505,7 +502,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
505502
pipe.set_progress_bar_config(disable=None)
506503
_, _, inputs = self.get_dummy_inputs(with_generator=False)
507504

508-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
505+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
509506
self.assertTrue(output_no_lora.shape == self.output_shape)
510507

511508
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -543,7 +540,7 @@ def test_simple_inference_with_text_lora_fused(self):
543540
pipe.set_progress_bar_config(disable=None)
544541
_, _, inputs = self.get_dummy_inputs(with_generator=False)
545542

546-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
543+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
547544
self.assertTrue(output_no_lora.shape == self.output_shape)
548545

549546
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -575,7 +572,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
575572
pipe.set_progress_bar_config(disable=None)
576573
_, _, inputs = self.get_dummy_inputs(with_generator=False)
577574

578-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
575+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
579576
self.assertTrue(output_no_lora.shape == self.output_shape)
580577

581578
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -610,7 +607,7 @@ def test_simple_inference_with_text_lora_save_load(self):
610607
pipe.set_progress_bar_config(disable=None)
611608
_, _, inputs = self.get_dummy_inputs(with_generator=False)
612609

613-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
610+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
614611
self.assertTrue(output_no_lora.shape == self.output_shape)
615612

616613
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -661,7 +658,7 @@ def test_simple_inference_with_partial_text_lora(self):
661658
pipe.set_progress_bar_config(disable=None)
662659
_, _, inputs = self.get_dummy_inputs(with_generator=False)
663660

664-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
661+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
665662
self.assertTrue(output_no_lora.shape == self.output_shape)
666663

667664
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -712,7 +709,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
712709
pipe.set_progress_bar_config(disable=None)
713710
_, _, inputs = self.get_dummy_inputs(with_generator=False)
714711

715-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
712+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
716713
self.assertTrue(output_no_lora.shape == self.output_shape)
717714

718715
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -755,7 +752,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
755752
pipe.set_progress_bar_config(disable=None)
756753
_, _, inputs = self.get_dummy_inputs(with_generator=False)
757754

758-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
755+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
759756
self.assertTrue(output_no_lora.shape == self.output_shape)
760757

761758
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -796,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
796793
pipe.set_progress_bar_config(disable=None)
797794
_, _, inputs = self.get_dummy_inputs(with_generator=False)
798795

799-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
796+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
800797
self.assertTrue(output_no_lora.shape == self.output_shape)
801798

802799
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -840,7 +837,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
840837
pipe.set_progress_bar_config(disable=None)
841838
_, _, inputs = self.get_dummy_inputs(with_generator=False)
842839

843-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
840+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
844841
self.assertTrue(output_no_lora.shape == self.output_shape)
845842

846843
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -878,7 +875,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
878875
pipe.set_progress_bar_config(disable=None)
879876
_, _, inputs = self.get_dummy_inputs(with_generator=False)
880877

881-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
878+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
882879
self.assertTrue(output_no_lora.shape == self.output_shape)
883880

884881
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -957,7 +954,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
957954
pipe.set_progress_bar_config(disable=None)
958955
_, _, inputs = self.get_dummy_inputs(with_generator=False)
959956

960-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
957+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
961958

962959
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
963960
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1086,7 +1083,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10861083
pipe.set_progress_bar_config(disable=None)
10871084
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10881085

1089-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1086+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
10901087

10911088
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
10921089
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1143,7 +1140,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
11431140
pipe.set_progress_bar_config(disable=None)
11441141
_, _, inputs = self.get_dummy_inputs(with_generator=False)
11451142

1146-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1143+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
11471144

11481145
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
11491146
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1306,7 +1303,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
13061303
pipe.set_progress_bar_config(disable=None)
13071304
_, _, inputs = self.get_dummy_inputs(with_generator=False)
13081305

1309-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1306+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
13101307

13111308
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
13121309
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1400,7 +1397,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
14001397
pipe.set_progress_bar_config(disable=None)
14011398
_, _, inputs = self.get_dummy_inputs(with_generator=False)
14021399

1403-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1400+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
14041401

14051402
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
14061403
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1644,7 +1641,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
16441641
pipe.set_progress_bar_config(disable=None)
16451642
_, _, inputs = self.get_dummy_inputs(with_generator=False)
16461643

1647-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1644+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
16481645
self.assertTrue(output_no_lora.shape == self.output_shape)
16491646

16501647
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1725,7 +1722,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
17251722
pipe.set_progress_bar_config(disable=None)
17261723
_, _, inputs = self.get_dummy_inputs(with_generator=False)
17271724

1728-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1725+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
17291726
self.assertTrue(output_no_lora.shape == self.output_shape)
17301727

17311728
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1780,7 +1777,7 @@ def test_simple_inference_with_dora(self):
17801777
pipe.set_progress_bar_config(disable=None)
17811778
_, _, inputs = self.get_dummy_inputs(with_generator=False)
17821779

1783-
output_no_dora_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1780+
output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
17841781
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
17851782

17861783
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1912,7 +1909,7 @@ def test_logs_info_when_no_lora_keys_found(self):
19121909
pipe.set_progress_bar_config(disable=None)
19131910

19141911
_, _, inputs = self.get_dummy_inputs(with_generator=False)
1915-
original_out = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1912+
original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
19161913

19171914
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
19181915
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1958,7 +1955,7 @@ def test_set_adapters_match_attention_kwargs(self):
19581955
pipe.set_progress_bar_config(disable=None)
19591956
_, _, inputs = self.get_dummy_inputs(with_generator=False)
19601957

1961-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
1958+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
19621959
self.assertTrue(output_no_lora.shape == self.output_shape)
19631960

19641961
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2312,7 +2309,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
23122309
pipe = self.pipeline_class(**components).to(torch_device)
23132310
_, _, inputs = self.get_dummy_inputs(with_generator=False)
23142311

2315-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
2312+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
23162313
self.assertTrue(output_no_lora.shape == self.output_shape)
23172314

23182315
pipe, _ = self.add_adapters_to_pipeline(
@@ -2362,7 +2359,7 @@ def test_inference_load_delete_load_adapters(self):
23622359
pipe.set_progress_bar_config(disable=None)
23632360
_, _, inputs = self.get_dummy_inputs(with_generator=False)
23642361

2365-
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
2362+
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
23662363

23672364
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
23682365
pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments

Comments
 (0)