Skip to content

Commit ead2e04

Browse files
committed
up
1 parent 8f405ed commit ead2e04

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_lora_expansion_works_for_absent_keys(self):
167167
pipe.set_progress_bar_config(disable=None)
168168
_, _, inputs = self.get_dummy_inputs(with_generator=False)
169169

170-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
170+
output_no_lora = self.get_cached_non_lora_output()
171171

172172
# Modify the config to have a layer which won't be present in the second LoRA we will load.
173173
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -214,7 +214,7 @@ def test_lora_expansion_works_for_extra_keys(self):
214214
pipe = pipe.to(torch_device)
215215
pipe.set_progress_bar_config(disable=None)
216216
_, _, inputs = self.get_dummy_inputs(with_generator=False)
217-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
217+
output_no_lora = self.get_cached_non_lora_output()
218218

219219
# Modify the config to have a layer which won't be present in the first LoRA we will load.
220220
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)

tests/lora/test_lora_layers_wanvace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def test_lora_exclude_modules_wanvace(self):
169169
pipe = self.pipeline_class(**components).to(torch_device)
170170
_, _, inputs = self.get_dummy_inputs(with_generator=False)
171171

172-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
172+
output_no_lora = self.get_cached_non_lora_output()
173173
self.assertTrue(output_no_lora.shape == self.output_shape)
174174

175175
# only supported for `denoiser` now

tests/lora/utils.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,12 @@ class PeftLoraLoaderMixinTests:
126126
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
127127
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
128128

129-
cached_non_lora_outputs = {}
129+
cached_non_lora_output = None
130130

131131
@pytest.fixture(scope="class", autouse=True)
132132
def get_base_pipeline_outputs(self):
133133
"""
134-
This fixture will be executed once per test class and will populate
135-
the cached_non_lora_outputs dictionary.
134+
This fixture is executed once per test class and caches the baseline outputs.
136135
"""
137136
components, _, _ = self.get_dummy_components(self.scheduler_cls)
138137
pipe = self.pipeline_class(**components)
@@ -143,11 +142,11 @@ def get_base_pipeline_outputs(self):
143142
# explicitly.
144143
_, _, inputs = self.get_dummy_inputs(with_generator=False)
145144
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
146-
self.cached_non_lora_outputs[self.scheduler_cls.__name__] = output_no_lora
145+
self.cached_non_lora_output = output_no_lora
147146

148147
# Ensures that there's no inconsistency when reusing the cache.
149148
yield
150-
self.cached_non_lora_outputs.clear()
149+
self.cached_non_lora_output = None
151150

152151
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
153152
if self.unet_kwargs and self.transformer_kwargs:
@@ -261,6 +260,12 @@ def get_dummy_inputs(self, with_generator=True):
261260

262261
return noise, input_ids, pipeline_inputs
263262

263+
def get_cached_non_lora_output(self):
264+
"""Return the cached baseline output produced without any LoRA adapters."""
265+
if self.cached_non_lora_output is None:
266+
raise ValueError("The baseline output cache is empty. Ensure the fixture has been executed.")
267+
return self.cached_non_lora_output
268+
264269
# Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
265270
def get_dummy_tokens(self):
266271
max_seq_length = 77
@@ -339,7 +344,7 @@ def test_simple_inference(self):
339344
"""
340345
Tests a simple inference and makes sure it works as expected
341346
"""
342-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
347+
output_no_lora = self.get_cached_non_lora_output()
343348
self.assertTrue(output_no_lora.shape == self.output_shape)
344349

345350
def test_simple_inference_with_text_lora(self):
@@ -353,7 +358,7 @@ def test_simple_inference_with_text_lora(self):
353358
pipe.set_progress_bar_config(disable=None)
354359
_, _, inputs = self.get_dummy_inputs(with_generator=False)
355360

356-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
361+
output_no_lora = self.get_cached_non_lora_output()
357362
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
358363

359364
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -478,7 +483,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
478483
pipe.set_progress_bar_config(disable=None)
479484
_, _, inputs = self.get_dummy_inputs(with_generator=False)
480485

481-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
486+
output_no_lora = self.get_cached_non_lora_output()
482487

483488
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
484489

@@ -514,7 +519,7 @@ def test_simple_inference_with_text_lora_fused(self):
514519
pipe.set_progress_bar_config(disable=None)
515520
_, _, inputs = self.get_dummy_inputs(with_generator=False)
516521

517-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
522+
output_no_lora = self.get_cached_non_lora_output()
518523

519524
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
520525

@@ -544,7 +549,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
544549
pipe.set_progress_bar_config(disable=None)
545550
_, _, inputs = self.get_dummy_inputs(with_generator=False)
546551

547-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
552+
output_no_lora = self.get_cached_non_lora_output()
548553

549554
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
550555

@@ -622,7 +627,7 @@ def test_simple_inference_with_partial_text_lora(self):
622627
pipe.set_progress_bar_config(disable=None)
623628
_, _, inputs = self.get_dummy_inputs(with_generator=False)
624629

625-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
630+
output_no_lora = self.get_cached_non_lora_output()
626631

627632
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
628633

@@ -746,7 +751,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
746751
pipe.set_progress_bar_config(disable=None)
747752
_, _, inputs = self.get_dummy_inputs(with_generator=False)
748753

749-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
754+
output_no_lora = self.get_cached_non_lora_output()
750755
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
751756

752757
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -787,7 +792,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
787792
pipe.set_progress_bar_config(disable=None)
788793
_, _, inputs = self.get_dummy_inputs(with_generator=False)
789794

790-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
795+
output_no_lora = self.get_cached_non_lora_output()
791796

792797
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
793798

@@ -821,7 +826,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
821826
pipe.set_progress_bar_config(disable=None)
822827
_, _, inputs = self.get_dummy_inputs(with_generator=False)
823828

824-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
829+
output_no_lora = self.get_cached_non_lora_output()
825830

826831
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
827832

@@ -895,7 +900,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
895900
pipe.set_progress_bar_config(disable=None)
896901
_, _, inputs = self.get_dummy_inputs(with_generator=False)
897902

898-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
903+
output_no_lora = self.get_cached_non_lora_output()
899904

900905
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
901906
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1019,7 +1024,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10191024
pipe.set_progress_bar_config(disable=None)
10201025
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10211026

1022-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1027+
output_no_lora = self.get_cached_non_lora_output()
10231028

10241029
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
10251030
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1075,7 +1080,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
10751080
pipe.set_progress_bar_config(disable=None)
10761081
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10771082

1078-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1083+
output_no_lora = self.get_cached_non_lora_output()
10791084

10801085
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
10811086
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1235,7 +1240,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
12351240
pipe.set_progress_bar_config(disable=None)
12361241
_, _, inputs = self.get_dummy_inputs(with_generator=False)
12371242

1238-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1243+
output_no_lora = self.get_cached_non_lora_output()
12391244

12401245
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
12411246
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1326,7 +1331,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13261331
pipe.set_progress_bar_config(disable=None)
13271332
_, _, inputs = self.get_dummy_inputs(with_generator=False)
13281333

1329-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1334+
output_no_lora = self.get_cached_non_lora_output()
13301335

13311336
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
13321337
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1632,7 +1637,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
16321637
pipe.set_progress_bar_config(disable=None)
16331638
_, _, inputs = self.get_dummy_inputs(with_generator=False)
16341639

1635-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1640+
output_no_lora = self.get_cached_non_lora_output()
16361641

16371642
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
16381643
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1807,7 +1812,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18071812
pipe.set_progress_bar_config(disable=None)
18081813

18091814
_, _, inputs = self.get_dummy_inputs(with_generator=False)
1810-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1815+
output_no_lora = self.get_cached_non_lora_output()
18111816

18121817
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
18131818
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1851,7 +1856,7 @@ def test_set_adapters_match_attention_kwargs(self):
18511856
pipe.set_progress_bar_config(disable=None)
18521857
_, _, inputs = self.get_dummy_inputs(with_generator=False)
18531858

1854-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
1859+
output_no_lora = self.get_cached_non_lora_output()
18551860
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
18561861

18571862
lora_scale = 0.5
@@ -2242,7 +2247,7 @@ def test_inference_load_delete_load_adapters(self):
22422247
pipe.set_progress_bar_config(disable=None)
22432248
_, _, inputs = self.get_dummy_inputs(with_generator=False)
22442249

2245-
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
2250+
output_no_lora = self.get_cached_non_lora_output()
22462251

22472252
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
22482253
pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments

Comments
 (0)