Skip to content

Commit 772c32e

Browse files
committed
up
1 parent 4256de9 commit 772c32e

File tree

1 file changed

+42
-39
lines changed

1 file changed

+42
-39
lines changed

tests/lora/utils.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,24 @@ class PeftLoraLoaderMixinTests:
129129
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
130130
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
131131

132-
cached_non_lora_outputs = {}
132+
cached_base_pipe_outputs = {}
133+
134+
def setUp(self):
135+
super().setUp()
136+
self.get_base_pipeline_output()
137+
138+
@classmethod
139+
def tearDownClass(cls):
140+
super().tearDownClass()
141+
cls.cached_base_pipe_outputs.clear()
142+
143+
def get_base_pipeline_output(self):
144+
scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
145+
if self.cached_base_pipe_outputs and all(k in self.cached_base_pipe_outputs for k in scheduler_names):
146+
return
133147

134-
@pytest.fixture(scope="class", autouse=True)
135-
def cache_non_lora_outputs(self):
136-
"""
137-
This fixture will be executed once per test class and will populate
138-
the cached_non_lora_outputs dictionary.
139-
"""
140148
for scheduler_cls in self.scheduler_classes:
141-
# Check if the output for this scheduler is already cached to avoid re-running
142-
if scheduler_cls.__name__ in self.cached_non_lora_outputs:
149+
if scheduler_cls.__name__ in self.cached_base_pipe_outputs:
143150
continue
144151

145152
components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -151,11 +158,7 @@ def cache_non_lora_outputs(self):
151158
# explicitly.
152159
_, _, inputs = self.get_dummy_inputs(with_generator=False)
153160
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
154-
self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
155-
156-
# Ensures that there's no inconsistency when reusing the cache.
157-
yield
158-
self.cached_non_lora_outputs.clear()
161+
self.cached_base_pipe_outputs[scheduler_cls.__name__] = output_no_lora
159162

160163
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
161164
if self.unet_kwargs and self.transformer_kwargs:
@@ -348,7 +351,7 @@ def test_simple_inference(self):
348351
Tests a simple inference and makes sure it works as expected
349352
"""
350353
for scheduler_cls in self.scheduler_classes:
351-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
354+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
352355
self.assertTrue(output_no_lora.shape == self.output_shape)
353356

354357
def test_simple_inference_with_text_lora(self):
@@ -363,7 +366,7 @@ def test_simple_inference_with_text_lora(self):
363366
pipe.set_progress_bar_config(disable=None)
364367
_, _, inputs = self.get_dummy_inputs(with_generator=False)
365368

366-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
369+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
367370
self.assertTrue(output_no_lora.shape == self.output_shape)
368371

369372
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -446,7 +449,7 @@ def test_low_cpu_mem_usage_with_loading(self):
446449
pipe.set_progress_bar_config(disable=None)
447450
_, _, inputs = self.get_dummy_inputs(with_generator=False)
448451

449-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
452+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
450453
self.assertTrue(output_no_lora.shape == self.output_shape)
451454

452455
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -502,7 +505,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
502505
pipe.set_progress_bar_config(disable=None)
503506
_, _, inputs = self.get_dummy_inputs(with_generator=False)
504507

505-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
508+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
506509
self.assertTrue(output_no_lora.shape == self.output_shape)
507510

508511
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -540,7 +543,7 @@ def test_simple_inference_with_text_lora_fused(self):
540543
pipe.set_progress_bar_config(disable=None)
541544
_, _, inputs = self.get_dummy_inputs(with_generator=False)
542545

543-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
546+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
544547
self.assertTrue(output_no_lora.shape == self.output_shape)
545548

546549
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -572,7 +575,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
572575
pipe.set_progress_bar_config(disable=None)
573576
_, _, inputs = self.get_dummy_inputs(with_generator=False)
574577

575-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
578+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
576579
self.assertTrue(output_no_lora.shape == self.output_shape)
577580

578581
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -607,7 +610,7 @@ def test_simple_inference_with_text_lora_save_load(self):
607610
pipe.set_progress_bar_config(disable=None)
608611
_, _, inputs = self.get_dummy_inputs(with_generator=False)
609612

610-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
613+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
611614
self.assertTrue(output_no_lora.shape == self.output_shape)
612615

613616
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -658,7 +661,7 @@ def test_simple_inference_with_partial_text_lora(self):
658661
pipe.set_progress_bar_config(disable=None)
659662
_, _, inputs = self.get_dummy_inputs(with_generator=False)
660663

661-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
664+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
662665
self.assertTrue(output_no_lora.shape == self.output_shape)
663666

664667
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -709,7 +712,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
709712
pipe.set_progress_bar_config(disable=None)
710713
_, _, inputs = self.get_dummy_inputs(with_generator=False)
711714

712-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
715+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
713716
self.assertTrue(output_no_lora.shape == self.output_shape)
714717

715718
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -752,7 +755,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
752755
pipe.set_progress_bar_config(disable=None)
753756
_, _, inputs = self.get_dummy_inputs(with_generator=False)
754757

755-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
758+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
756759
self.assertTrue(output_no_lora.shape == self.output_shape)
757760

758761
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -793,7 +796,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
793796
pipe.set_progress_bar_config(disable=None)
794797
_, _, inputs = self.get_dummy_inputs(with_generator=False)
795798

796-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
799+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
797800
self.assertTrue(output_no_lora.shape == self.output_shape)
798801

799802
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -837,7 +840,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
837840
pipe.set_progress_bar_config(disable=None)
838841
_, _, inputs = self.get_dummy_inputs(with_generator=False)
839842

840-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
843+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
841844
self.assertTrue(output_no_lora.shape == self.output_shape)
842845

843846
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -875,7 +878,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
875878
pipe.set_progress_bar_config(disable=None)
876879
_, _, inputs = self.get_dummy_inputs(with_generator=False)
877880

878-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
881+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
879882
self.assertTrue(output_no_lora.shape == self.output_shape)
880883

881884
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -954,7 +957,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
954957
pipe.set_progress_bar_config(disable=None)
955958
_, _, inputs = self.get_dummy_inputs(with_generator=False)
956959

957-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
960+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
958961

959962
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
960963
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1083,7 +1086,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10831086
pipe.set_progress_bar_config(disable=None)
10841087
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10851088

1086-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1089+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
10871090

10881091
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
10891092
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1140,7 +1143,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
11401143
pipe.set_progress_bar_config(disable=None)
11411144
_, _, inputs = self.get_dummy_inputs(with_generator=False)
11421145

1143-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1146+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
11441147

11451148
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
11461149
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1303,7 +1306,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
13031306
pipe.set_progress_bar_config(disable=None)
13041307
_, _, inputs = self.get_dummy_inputs(with_generator=False)
13051308

1306-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1309+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
13071310

13081311
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
13091312
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1397,7 +1400,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13971400
pipe.set_progress_bar_config(disable=None)
13981401
_, _, inputs = self.get_dummy_inputs(with_generator=False)
13991402

1400-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1403+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
14011404

14021405
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
14031406
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1641,7 +1644,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
16411644
pipe.set_progress_bar_config(disable=None)
16421645
_, _, inputs = self.get_dummy_inputs(with_generator=False)
16431646

1644-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1647+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
16451648
self.assertTrue(output_no_lora.shape == self.output_shape)
16461649

16471650
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1722,7 +1725,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
17221725
pipe.set_progress_bar_config(disable=None)
17231726
_, _, inputs = self.get_dummy_inputs(with_generator=False)
17241727

1725-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1728+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
17261729
self.assertTrue(output_no_lora.shape == self.output_shape)
17271730

17281731
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1777,7 +1780,7 @@ def test_simple_inference_with_dora(self):
17771780
pipe.set_progress_bar_config(disable=None)
17781781
_, _, inputs = self.get_dummy_inputs(with_generator=False)
17791782

1780-
output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1783+
output_no_dora_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
17811784
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
17821785

17831786
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1909,7 +1912,7 @@ def test_logs_info_when_no_lora_keys_found(self):
19091912
pipe.set_progress_bar_config(disable=None)
19101913

19111914
_, _, inputs = self.get_dummy_inputs(with_generator=False)
1912-
original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
1915+
original_out = self.cached_base_pipe_outputs[scheduler_cls.__name__]
19131916

19141917
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
19151918
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1955,7 +1958,7 @@ def test_set_adapters_match_attention_kwargs(self):
19551958
pipe.set_progress_bar_config(disable=None)
19561959
_, _, inputs = self.get_dummy_inputs(with_generator=False)
19571960

1958-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
1961+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
19591962
self.assertTrue(output_no_lora.shape == self.output_shape)
19601963

19611964
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2309,7 +2312,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
23092312
pipe = self.pipeline_class(**components).to(torch_device)
23102313
_, _, inputs = self.get_dummy_inputs(with_generator=False)
23112314

2312-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
2315+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
23132316
self.assertTrue(output_no_lora.shape == self.output_shape)
23142317

23152318
pipe, _ = self.add_adapters_to_pipeline(
@@ -2359,7 +2362,7 @@ def test_inference_load_delete_load_adapters(self):
23592362
pipe.set_progress_bar_config(disable=None)
23602363
_, _, inputs = self.get_dummy_inputs(with_generator=False)
23612364

2362-
output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
2365+
output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
23632366

23642367
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
23652368
pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments

Comments
 (0)