
Commit 8efb5c4

up
1 parent fbcdf8b commit 8efb5c4

File tree

1 file changed: +32 -30 lines changed

tests/lora/utils.py

Lines changed: 32 additions & 30 deletions
@@ -133,16 +133,7 @@ def get_base_pipeline_outputs(self):
         """
         This fixture is executed once per test class and caches the baseline outputs.
         """
-        components, _, _ = self.get_dummy_components(self.scheduler_cls)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
-        # explicitly.
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.cached_non_lora_output = output_no_lora
+        self.cached_non_lora_output = self._compute_baseline_output()

         # Ensures that there's no inconsistency when reusing the cache.
         yield
@@ -260,12 +251,23 @@ def get_dummy_inputs(self, with_generator=True):

         return noise, input_ids, pipeline_inputs

-    def get_cached_non_lora_output(self):
+    def get_base_pipe_output(self):
         """Return the cached baseline output produced without any LoRA adapters."""
         if self.cached_non_lora_output is None:
-            raise ValueError("The baseline output cache is empty. Ensure the fixture has been executed.")
+            self.cached_non_lora_output = self._compute_baseline_output()
         return self.cached_non_lora_output

+    def _compute_baseline_output(self):
+        components, _, _ = self.get_dummy_components(self.scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+        # explicitly.
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        return pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
     def get_dummy_tokens(self):
         max_seq_length = 77
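
The two hunks above are the heart of the commit: the eager, fixture-only cache becomes a memoized accessor backed by the extracted _compute_baseline_output() helper, so the first call computes and caches the baseline and every later call reuses it. A minimal, self-contained sketch of that pattern, with invented names (BaselineCache, compute_fn) that are not part of this diff:

    # Illustrative memoization sketch; names are hypothetical, not from
    # tests/lora/utils.py.
    class BaselineCache:
        def __init__(self, compute_fn):
            self._compute_fn = compute_fn  # expensive computation, run at most once
            self._cached = None

        def get(self):
            # Lazily compute on first access instead of raising when a
            # fixture has not populated the cache (the old behavior).
            if self._cached is None:
                self._cached = self._compute_fn()
            return self._cached

    cache = BaselineCache(lambda: sum(range(1_000)))  # stand-in for pipeline inference
    assert cache.get() is cache.get()  # the second call reuses the cached result

The lazy fallback also makes get_base_pipe_output() safe to call from tests that run before, or without, the caching fixture.
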
@@ -344,7 +346,7 @@ def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
         """
-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()
         self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
@@ -358,7 +360,7 @@ def test_simple_inference_with_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -483,7 +485,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -519,7 +521,7 @@ def test_simple_inference_with_text_lora_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -549,7 +551,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -627,7 +629,7 @@ def test_simple_inference_with_partial_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -751,7 +753,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -792,7 +794,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

@@ -826,7 +828,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

@@ -900,7 +902,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1024,7 +1026,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1080,7 +1082,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1240,7 +1242,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1331,7 +1333,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1637,7 +1639,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1812,7 +1814,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         pipe.set_progress_bar_config(disable=None)

         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
@@ -1856,7 +1858,7 @@ def test_set_adapters_match_attention_kwargs(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         lora_scale = 0.5
@@ -2247,7 +2249,7 @@ def test_inference_load_delete_load_adapters(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = self.get_cached_non_lora_output()
+        output_no_lora = self.get_base_pipe_output()

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config)
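
Every hunk after the helper extraction is the same mechanical rename at the call sites. For orientation, a typical consumer now reads like this sketch; the final np.allclose assertion is an assumed example of how such tests compare outputs and does not appear in this diff:

    output_no_lora = self.get_base_pipe_output()
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
    output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
    # Assumed assertion style: the LoRA output should diverge from the baseline.
    self.assertFalse(np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3))
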
