Skip to content

Commit a8c5801

Browse files
committed
update
1 parent 2743c9e commit a8c5801

File tree

1 file changed

+33
-25
lines changed

1 file changed

+33
-25
lines changed

tests/lora/utils.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ def get_base_pipe_outs(self):
158158
cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
159159

160160
setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
161+
162+
def get_base_pipeline_output(self, scheduler_cls):
163+
"""
164+
Returns the cached base pipeline output for the given scheduler.
165+
Properly handles accessing the class-level cache.
166+
"""
167+
cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
168+
return cached_base_pipe_outs[scheduler_cls.__name__]
161169

162170
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
163171
if self.unet_kwargs and self.transformer_kwargs:
@@ -350,7 +358,7 @@ def test_simple_inference(self):
350358
Tests a simple inference and makes sure it works as expected
351359
"""
352360
for scheduler_cls in self.scheduler_classes:
353-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
361+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
354362
self.assertTrue(output_no_lora.shape == self.output_shape)
355363

356364
def test_simple_inference_with_text_lora(self):
@@ -365,7 +373,7 @@ def test_simple_inference_with_text_lora(self):
365373
pipe.set_progress_bar_config(disable=None)
366374
_, _, inputs = self.get_dummy_inputs(with_generator=False)
367375

368-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
376+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
369377
self.assertTrue(output_no_lora.shape == self.output_shape)
370378

371379
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -448,7 +456,7 @@ def test_low_cpu_mem_usage_with_loading(self):
448456
pipe.set_progress_bar_config(disable=None)
449457
_, _, inputs = self.get_dummy_inputs(with_generator=False)
450458

451-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
459+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
452460
self.assertTrue(output_no_lora.shape == self.output_shape)
453461

454462
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -504,7 +512,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
504512
pipe.set_progress_bar_config(disable=None)
505513
_, _, inputs = self.get_dummy_inputs(with_generator=False)
506514

507-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
515+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
508516
self.assertTrue(output_no_lora.shape == self.output_shape)
509517

510518
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -542,7 +550,7 @@ def test_simple_inference_with_text_lora_fused(self):
542550
pipe.set_progress_bar_config(disable=None)
543551
_, _, inputs = self.get_dummy_inputs(with_generator=False)
544552

545-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
553+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
546554
self.assertTrue(output_no_lora.shape == self.output_shape)
547555

548556
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -574,7 +582,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
574582
pipe.set_progress_bar_config(disable=None)
575583
_, _, inputs = self.get_dummy_inputs(with_generator=False)
576584

577-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
585+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
578586
self.assertTrue(output_no_lora.shape == self.output_shape)
579587

580588
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -609,7 +617,7 @@ def test_simple_inference_with_text_lora_save_load(self):
609617
pipe.set_progress_bar_config(disable=None)
610618
_, _, inputs = self.get_dummy_inputs(with_generator=False)
611619

612-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
620+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
613621
self.assertTrue(output_no_lora.shape == self.output_shape)
614622

615623
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -660,7 +668,7 @@ def test_simple_inference_with_partial_text_lora(self):
660668
pipe.set_progress_bar_config(disable=None)
661669
_, _, inputs = self.get_dummy_inputs(with_generator=False)
662670

663-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
671+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
664672
self.assertTrue(output_no_lora.shape == self.output_shape)
665673

666674
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -711,7 +719,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
711719
pipe.set_progress_bar_config(disable=None)
712720
_, _, inputs = self.get_dummy_inputs(with_generator=False)
713721

714-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
722+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
715723
self.assertTrue(output_no_lora.shape == self.output_shape)
716724

717725
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -754,7 +762,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
754762
pipe.set_progress_bar_config(disable=None)
755763
_, _, inputs = self.get_dummy_inputs(with_generator=False)
756764

757-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
765+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
758766
self.assertTrue(output_no_lora.shape == self.output_shape)
759767

760768
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -795,7 +803,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
795803
pipe.set_progress_bar_config(disable=None)
796804
_, _, inputs = self.get_dummy_inputs(with_generator=False)
797805

798-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
806+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
799807
self.assertTrue(output_no_lora.shape == self.output_shape)
800808

801809
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -839,7 +847,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
839847
pipe.set_progress_bar_config(disable=None)
840848
_, _, inputs = self.get_dummy_inputs(with_generator=False)
841849

842-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
850+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
843851
self.assertTrue(output_no_lora.shape == self.output_shape)
844852

845853
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -877,7 +885,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
877885
pipe.set_progress_bar_config(disable=None)
878886
_, _, inputs = self.get_dummy_inputs(with_generator=False)
879887

880-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
888+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
881889
self.assertTrue(output_no_lora.shape == self.output_shape)
882890

883891
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -956,7 +964,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
956964
pipe.set_progress_bar_config(disable=None)
957965
_, _, inputs = self.get_dummy_inputs(with_generator=False)
958966

959-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
967+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
960968

961969
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
962970
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1085,7 +1093,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10851093
pipe.set_progress_bar_config(disable=None)
10861094
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10871095

1088-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1096+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
10891097

10901098
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
10911099
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1142,7 +1150,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
11421150
pipe.set_progress_bar_config(disable=None)
11431151
_, _, inputs = self.get_dummy_inputs(with_generator=False)
11441152

1145-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1153+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
11461154

11471155
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
11481156
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1305,7 +1313,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
13051313
pipe.set_progress_bar_config(disable=None)
13061314
_, _, inputs = self.get_dummy_inputs(with_generator=False)
13071315

1308-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1316+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
13091317

13101318
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
13111319
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1399,7 +1407,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13991407
pipe.set_progress_bar_config(disable=None)
14001408
_, _, inputs = self.get_dummy_inputs(with_generator=False)
14011409

1402-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1410+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
14031411

14041412
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
14051413
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1643,7 +1651,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
16431651
pipe.set_progress_bar_config(disable=None)
16441652
_, _, inputs = self.get_dummy_inputs(with_generator=False)
16451653

1646-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1654+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
16471655
self.assertTrue(output_no_lora.shape == self.output_shape)
16481656

16491657
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1724,7 +1732,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
17241732
pipe.set_progress_bar_config(disable=None)
17251733
_, _, inputs = self.get_dummy_inputs(with_generator=False)
17261734

1727-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1735+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
17281736
self.assertTrue(output_no_lora.shape == self.output_shape)
17291737

17301738
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1779,7 +1787,7 @@ def test_simple_inference_with_dora(self):
17791787
pipe.set_progress_bar_config(disable=None)
17801788
_, _, inputs = self.get_dummy_inputs(with_generator=False)
17811789

1782-
output_no_dora_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1790+
output_no_dora_lora = self.get_base_pipeline_output(scheduler_cls)
17831791
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
17841792

17851793
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1911,7 +1919,7 @@ def test_logs_info_when_no_lora_keys_found(self):
19111919
pipe.set_progress_bar_config(disable=None)
19121920

19131921
_, _, inputs = self.get_dummy_inputs(with_generator=False)
1914-
original_out = self.cached_base_pipe_outs[scheduler_cls.__name__]
1922+
original_out = self.get_base_pipeline_output(scheduler_cls)
19151923

19161924
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
19171925
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1957,7 +1965,7 @@ def test_set_adapters_match_attention_kwargs(self):
19571965
pipe.set_progress_bar_config(disable=None)
19581966
_, _, inputs = self.get_dummy_inputs(with_generator=False)
19591967

1960-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
1968+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
19611969
self.assertTrue(output_no_lora.shape == self.output_shape)
19621970

19631971
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2311,7 +2319,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
23112319
pipe = self.pipeline_class(**components).to(torch_device)
23122320
_, _, inputs = self.get_dummy_inputs(with_generator=False)
23132321

2314-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
2322+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
23152323
self.assertTrue(output_no_lora.shape == self.output_shape)
23162324

23172325
pipe, _ = self.add_adapters_to_pipeline(
@@ -2361,7 +2369,7 @@ def test_inference_load_delete_load_adapters(self):
23612369
pipe.set_progress_bar_config(disable=None)
23622370
_, _, inputs = self.get_dummy_inputs(with_generator=False)
23632371

2364-
output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
2372+
output_no_lora = self.get_base_pipeline_output(scheduler_cls)
23652373

23662374
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
23672375
pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments

Comments
 (0)