
Commit 2743c9e

update
1 parent 2c47a2f commit 2743c9e

1 file changed (+40, -38)

tests/lora/utils.py

Lines changed: 40 additions & 38 deletions
@@ -129,17 +129,21 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    cached_non_lora_outputs = {}
+    cached_base_pipe_outs = {}

-    @pytest.fixture(scope="class", autouse=True)
-    def cache_non_lora_outputs(self):
-        """
-        This fixture will be executed once per test class and will populate
-        the cached_non_lora_outputs dictionary.
-        """
+    def setUp(self):
+        self.get_base_pipe_outs()
+        super().setUp()
+
+    def get_base_pipe_outs(self):
+        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
+        all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
+        if cached_base_pipe_outs is not None and all(k in cached_base_pipe_outs for k in all_scheduler_names):
+            return
+
+        cached_base_pipe_outs = cached_base_pipe_outs or {}
         for scheduler_cls in self.scheduler_classes:
-            # Check if the output for this scheduler is already cached to avoid re-running
-            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+            if scheduler_cls.__name__ in cached_base_pipe_outs:
                 continue

             components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -151,11 +155,9 @@ def cache_non_lora_outputs(self):
             # explicitly.
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+            cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})

-        # Ensures that there's no inconsistency when reusing the cache.
-        yield
-        self.cached_non_lora_outputs.clear()
+        setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)

     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
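The net effect of these two hunks: the class-scoped pytest fixture (with its yield-based teardown that cleared the cache) is replaced by a plain setUp hook that lazily memoizes the baseline, non-LoRA pipeline outputs on the test class itself. Below is a minimal, self-contained sketch of that pattern; scheduler_names, run_baseline, and the test class are illustrative stand-ins, not code from the diffusers test suite.

import unittest


class CachedBaselineTests(unittest.TestCase):
    # Illustrative stand-in for scheduler_classes; any hashable keys work.
    scheduler_names = ["DDIMScheduler", "EulerDiscreteScheduler"]
    cached_base_pipe_outs = {}

    def setUp(self):
        # Runs before every test; after the first test of the class this
        # is only a cheap membership check.
        self.get_base_pipe_outs()
        super().setUp()

    def get_base_pipe_outs(self):
        # Read the cache from the class object: unittest creates a fresh
        # instance per test method, so an instance attribute would not
        # survive from one test to the next.
        cache = getattr(type(self), "cached_base_pipe_outs", {})
        if all(name in cache for name in self.scheduler_names):
            return  # all baselines cached; skip the expensive work

        cache = cache or {}
        for name in self.scheduler_names:
            if name in cache:
                continue
            cache[name] = self.run_baseline(name)
        # Write back to the class so later instances see the filled cache.
        setattr(type(self), "cached_base_pipe_outs", cache)

    def run_baseline(self, name):
        # Stand-in for the real seeded pipeline call in the test suite.
        return f"output-for-{name}"

    def test_baseline_available(self):
        self.assertIn("DDIMScheduler", self.cached_base_pipe_outs)


if __name__ == "__main__":
    unittest.main()

Inside a test, reading self.cached_base_pipe_outs[...] resolves to the class attribute, which is why the remaining hunks below only rename the dictionary. One behavioral difference from the fixture version: there is no teardown step anymore, so the cache persists for the lifetime of the test process rather than being cleared after each class.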
@@ -348,7 +350,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
@@ -363,7 +365,7 @@ def test_simple_inference_with_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -446,7 +448,7 @@ def test_low_cpu_mem_usage_with_loading(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -502,7 +504,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -540,7 +542,7 @@ def test_simple_inference_with_text_lora_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -572,7 +574,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -607,7 +609,7 @@ def test_simple_inference_with_text_lora_save_load(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -658,7 +660,7 @@ def test_simple_inference_with_partial_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -709,7 +711,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -752,7 +754,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -793,7 +795,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -837,7 +839,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -875,7 +877,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -954,7 +956,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1083,7 +1085,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]

         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1140,7 +1142,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1303,7 +1305,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1397,7 +1399,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1641,7 +1643,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1722,7 +1724,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1777,7 +1779,7 @@ def test_simple_inference_with_dora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_dora_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1909,7 +1911,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         pipe.set_progress_bar_config(disable=None)

         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        original_out = self.cached_base_pipe_outs[scheduler_cls.__name__]

         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
@@ -1955,7 +1957,7 @@ def test_set_adapters_match_attention_kwargs(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2309,7 +2311,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(
@@ -2359,7 +2361,7 @@ def test_inference_load_delete_load_adapters(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+        output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments
