
Commit 02fd92e

cache non lora pipeline outputs.
1 parent fc337d5 commit 02fd92e

File tree

1 file changed (+49, -31 lines)


tests/lora/utils.py

Lines changed: 49 additions & 31 deletions
@@ -129,6 +129,30 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
+    cached_non_lora_outputs = {}
+
+    @pytest.fixture(scope="class", autouse=True)
+    def cache_non_lora_outputs(self, request):
+        """
+        This fixture is executed once per test class and populates
+        the cached_non_lora_outputs dictionary.
+        """
+        for scheduler_cls in self.scheduler_classes:
+            # Check if the output for this scheduler is already cached to avoid re-running it.
+            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+                continue
+
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+
+            # Always request the inputs without the `generator` and pass the
+            # `generator` explicitly to the pipeline call.
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
@@ -320,13 +344,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-
-            _, _, inputs = self.get_dummy_inputs()
-            output_no_lora = pipe(**inputs)[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
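All of the remaining hunks apply the same one-line substitution. Reconstructed from the hunk context, each affected test now follows roughly this shape (a sketch assembled from the surrounding diff, not the verbatim file; setup lines outside the hunks are abbreviated):

# Sketch of a post-change test body, per the hunks below.
for scheduler_cls in self.scheduler_classes:
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
    pipe = self.pipeline_class(**components).to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # The baseline now comes from the class-level cache instead of a fresh
    # forward pass through the pipeline.
    output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
    self.assertTrue(output_no_lora.shape == self.output_shape)

    # LoRA adapters are attached afterwards and their outputs are compared
    # against the cached baseline.
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)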
@@ -341,7 +359,7 @@ def test_simple_inference_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -424,7 +442,7 @@ def test_low_cpu_mem_usage_with_loading(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -480,7 +498,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -518,7 +536,7 @@ def test_simple_inference_with_text_lora_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -550,7 +568,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -585,7 +603,7 @@ def test_simple_inference_with_text_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -636,7 +654,7 @@ def test_simple_inference_with_partial_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -687,7 +705,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -730,7 +748,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -771,7 +789,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -815,7 +833,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -853,7 +871,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -932,7 +950,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1061,7 +1079,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1118,7 +1136,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1281,7 +1299,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1375,7 +1393,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1619,7 +1637,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1700,7 +1718,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1755,7 +1773,7 @@ def test_simple_inference_with_dora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1887,7 +1905,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
@@ -1933,7 +1951,7 @@ def test_set_adapters_match_attention_kwargs(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2287,7 +2305,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe, _ = self.add_adapters_to_pipeline(
@@ -2337,7 +2355,7 @@ def test_inference_load_delete_load_adapters(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments
