3 changes: 0 additions & 3 deletions tests/lora/test_lora_layers_cogview4.py
@@ -129,9 +129,6 @@ def test_simple_inference_save_pretrained(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

with tempfile.TemporaryDirectory() as tmpdirname:
11 changes: 3 additions & 8 deletions tests/lora/test_lora_layers_flux.py
@@ -122,9 +122,6 @@ def test_with_alpha_in_state_dict(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

@@ -170,8 +167,7 @@ def test_lora_expansion_works_for_absent_keys(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -218,9 +214,7 @@ def test_lora_expansion_works_for_extra_keys(self):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -329,6 +323,7 @@ def get_dummy_inputs(self, with_generator=True):
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

np.random.seed(0)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
2 changes: 1 addition & 1 deletion tests/lora/test_lora_layers_wanvace.py
@@ -169,7 +169,7 @@ def test_lora_exclude_modules_wanvace(self):
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

# only supported for `denoiser` now
106 changes: 44 additions & 62 deletions tests/lora/utils.py
@@ -126,13 +126,36 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

def get_dummy_components(self, use_dora=False, lora_alpha=None):
cached_non_lora_outputs = {}

@pytest.fixture(scope="class", autouse=True)
Collaborator comment: I think this might be a leftover from a previous refactoring. It's not used anywhere; I think it can be removed.

def get_base_pipeline_outputs(self):
"""
This fixture will be executed once per test class and will populate
the cached_non_lora_outputs dictionary.
"""
components, _, _ = self.get_dummy_components(self.scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

# Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
# explicitly.
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.cached_non_lora_outputs[self.scheduler_cls.__name__] = output_no_lora

# Ensures that there's no inconsistency when reusing the cache.
yield
self.cached_non_lora_outputs.clear()

def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

scheduler_cls = self.scheduler_cls
scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha

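Aside (not part of the PR's changes): the fixture above follows the standard pytest pattern of a class-scoped, autouse fixture that computes an expensive baseline once and shares it with every test in the class through a mutable class attribute. Below is a minimal, self-contained sketch of that pattern; `FakeScheduler` and `FakePipeline` are hypothetical stand-ins used only for illustration.

```python
import pytest


class FakeScheduler:
    pass


class FakePipeline:
    """Stand-in for an expensive pipeline call (hypothetical, illustration only)."""

    def __call__(self, seed=0):
        return [f"baseline-output-for-seed-{seed}"]


class TestCachedBaselineExample:
    scheduler_cls = FakeScheduler
    # Class-level dict: mutated in place, so it is shared by all test instances.
    cached_non_lora_outputs = {}

    @pytest.fixture(scope="class", autouse=True)
    def get_base_pipeline_outputs(self):
        # Executed once per test class: compute the costly baseline a single time.
        pipe = FakePipeline()
        self.cached_non_lora_outputs[self.scheduler_cls.__name__] = pipe(seed=0)[0]
        yield
        # Teardown after the last test in the class, so a rerun starts clean.
        self.cached_non_lora_outputs.clear()

    def test_reads_cached_baseline(self):
        # Tests read the shared baseline instead of re-running the pipeline.
        output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
        assert output_no_lora == "baseline-output-for-seed-0"
```

Because `cached_non_lora_outputs` is a class-level dict that is mutated in place rather than reassigned, every test in the class sees the same cached baseline, and the teardown after `yield` clears it so subsequent runs of the class start from a clean state.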
@@ -316,13 +339,7 @@ def test_simple_inference(self):
"""
Tests a simple inference and makes sure it works as expected
"""
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

_, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs)[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
self.assertTrue(output_no_lora.shape == self.output_shape)

def test_simple_inference_with_text_lora(self):
@@ -336,9 +353,7 @@ def test_simple_inference_with_text_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -414,9 +429,6 @@ def test_low_cpu_mem_usage_with_loading(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -466,8 +478,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -503,8 +514,7 @@ def test_simple_inference_with_text_lora_fused(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -534,8 +544,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -566,9 +575,6 @@ def test_simple_inference_with_text_lora_save_load(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -616,8 +622,7 @@ def test_simple_inference_with_partial_text_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

@@ -666,9 +671,6 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -708,9 +710,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -747,9 +746,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -790,8 +787,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

@@ -825,8 +821,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

@@ -900,7 +895,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1024,7 +1019,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1080,7 +1075,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1240,7 +1235,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1331,7 +1326,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1551,7 +1546,6 @@ def test_get_list_adapters(self):

self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

@require_peft_version_greater(peft_version="0.6.2")
def test_simple_inference_with_text_lora_denoiser_fused_multi(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
):
@@ -1565,9 +1559,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1641,8 +1632,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1685,7 +1675,6 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
"LoRA should change the output",
)

@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
pipe = self.pipeline_class(**components)
@@ -1695,7 +1684,6 @@ def test_simple_inference_with_dora(self):

output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1783,7 +1771,6 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
@@ -1820,7 +1807,7 @@ def test_logs_info_when_no_lora_keys_found(self):
pipe.set_progress_bar_config(disable=None)

_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1832,7 +1819,7 @@

denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))

# test only for text encoder
for lora_module in self.pipeline_class._lora_loadable_modules:
@@ -1864,9 +1851,7 @@ def test_set_adapters_match_attention_kwargs(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

lora_scale = 0.5
@@ -2212,9 +2197,6 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
@@ -2260,7 +2242,7 @@ def test_inference_load_delete_load_adapters(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = self.cached_non_lora_outputs[self.scheduler_cls.__name__]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)