Merged
Changes from 4 commits
1 change: 1 addition & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -331,6 +331,7 @@ def get_dummy_inputs(self, with_generator=True):
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

+np.random.seed(0)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
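Aside, not part of the diff: the added np.random.seed(0) call presumably pins NumPy's global RNG so that the randomly generated control_image is identical between the class-level cached run and each individual test run, keeping the cached no-LoRA baseline comparable. A minimal standalone sketch of that effect (the helper make_control_image is illustrative, not from the test suite):

import numpy as np
from PIL import Image

def make_control_image():
    # Seeding the global NumPy RNG immediately before drawing pixels makes the
    # image reproducible across separate calls.
    np.random.seed(0)
    return Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8"))

# Two independent calls produce byte-identical images, so a cached baseline run and a
# fresh per-test run see the same pipeline input.
assert (np.asarray(make_control_image()) == np.asarray(make_control_image())).all()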
84 changes: 53 additions & 31 deletions tests/lora/utils.py
@@ -129,6 +129,34 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

+cached_non_lora_outputs = {}
+
+@pytest.fixture(scope="class", autouse=True)
Collaborator:
Think this might be a leftover from previous refactoring. It's not used anywhere. Think it can be removed.

+def cache_non_lora_outputs(self):
+"""
+This fixture will be executed once per test class and will populate
+the cached_non_lora_outputs dictionary.
+"""
+for scheduler_cls in self.scheduler_classes:
+# Check if the output for this scheduler is already cached to avoid re-running
+if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+continue
+
+components, _, _ = self.get_dummy_components(scheduler_cls)
+pipe = self.pipeline_class(**components)
+pipe = pipe.to(torch_device)
+pipe.set_progress_bar_config(disable=None)
+
+# Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+# explicitly.
+_, _, inputs = self.get_dummy_inputs(with_generator=False)
+output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+
+# Ensures that there's no inconsistency when reusing the cache.
+yield
+self.cached_non_lora_outputs.clear()
+
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
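For readers unfamiliar with the pattern above: it combines a class attribute used as a shared cache with a class-scoped, autouse pytest fixture, so the expensive no-LoRA baseline is computed once per test class and reused by every test. A minimal self-contained sketch of the same mechanism, using illustrative names (TestWithCachedBaseline, expensive_baseline) rather than the real diffusers helpers:

import pytest

def expensive_baseline(key):
    # Stand-in for the real per-scheduler pipeline run done in cache_non_lora_outputs.
    return f"baseline-{key}"

class TestWithCachedBaseline:
    cached = {}  # class attribute, shared by every test in this class

    @pytest.fixture(scope="class", autouse=True)
    def populate_cache(self):
        # scope="class" runs this once per test class; autouse=True means no test
        # needs to request it explicitly.
        for key in ("a", "b"):
            if key in self.cached:
                continue  # already computed, skip the expensive work
            self.cached[key] = expensive_baseline(key)
        yield
        # Clear afterwards so cached state does not leak into other test classes.
        self.cached.clear()

    def test_reads_from_cache(self):
        assert self.cached["a"] == "baseline-a"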
@@ -320,13 +348,7 @@ def test_simple_inference(self):
Tests a simple inference and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
-components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
-pipe = self.pipeline_class(**components)
-pipe = pipe.to(torch_device)
-pipe.set_progress_bar_config(disable=None)
-
-_, _, inputs = self.get_dummy_inputs()
-output_no_lora = pipe(**inputs)[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

def test_simple_inference_with_text_lora(self):
@@ -341,7 +363,7 @@ def test_simple_inference_with_text_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -424,7 +446,7 @@ def test_low_cpu_mem_usage_with_loading(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -480,7 +502,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -518,7 +540,7 @@ def test_simple_inference_with_text_lora_fused(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -550,7 +572,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -585,7 +607,7 @@ def test_simple_inference_with_text_lora_save_load(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -636,7 +658,7 @@ def test_simple_inference_with_partial_text_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -687,7 +709,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -730,7 +752,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -771,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -815,7 +837,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -853,7 +875,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -932,7 +954,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1061,7 +1083,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1118,7 +1140,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1281,7 +1303,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1375,7 +1397,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1619,7 +1641,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1700,7 +1722,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1755,7 +1777,7 @@ def test_simple_inference_with_dora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1887,7 +1909,7 @@ def test_logs_info_when_no_lora_keys_found(self):
pipe.set_progress_bar_config(disable=None)

_, _, inputs = self.get_dummy_inputs(with_generator=False)
-original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]

no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1933,7 +1955,7 @@ def test_set_adapters_match_attention_kwargs(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2287,7 +2309,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(
@@ -2337,7 +2359,7 @@ def test_inference_load_delete_load_adapters(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

-output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)