Commit c34fc34

[Tests] QoL improvements to the LoRA test suite (#10304)
* misc lora test improvements.
* updates
* fixes to tests
1 parent 5fcee4a commit c34fc34

File tree

3 files changed: +132 -118 lines changed

- tests/lora/test_lora_layers_flux.py
- tests/lora/test_lora_layers_ltx_video.py
- tests/lora/utils.py

tests/lora/test_lora_layers_flux.py

Lines changed: 20 additions & 73 deletions
@@ -36,7 +36,6 @@
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
     require_peft_backend,
-    require_peft_version_greater,
     require_torch_gpu,
     slow,
     torch_device,
@@ -331,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -340,85 +340,32 @@ def test_lora_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-    @require_peft_version_greater("0.13.2")
-    def test_lora_B_bias(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # keep track of the bias values of the base layers to perform checks later.
-        bias_values = {}
-        for name, module in pipe.transformer.named_modules():
-            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
-                if module.bias is not None:
-                    bias_values[name] = module.bias.data.clone()
-
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
-        logger.setLevel(logging.INFO)
-
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        denoiser_lora_config.lora_bias = False
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        pipe.delete_adapters("adapter-1")
-
-        denoiser_lora_config.lora_bias = True
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
-        self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
-        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
-
-    # for now this is flux control lora specific but can be generalized later and added to ./utils.py
-    def test_correct_lora_configs_with_different_ranks(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Testing opposite direction where the LoRA params are zero-padded.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        pipe.transformer.delete_adapters("adapter-1")
-
-        # change the rank_pattern
-        updated_rank = denoiser_lora_config.r * 2
-        denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
-            "single_transformer_blocks.0.attn.to_k": updated_rank
+        dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
         }
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
-        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-        pipe.transformer.delete_adapters("adapter-1")
-
-        # similarly change the alpha_pattern
-        updated_alpha = denoiser_lora_config.lora_alpha * 2
-        denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
-            "single_transformer_blocks.0.attn.to_k": updated_alpha
-        }
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
-        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
-        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
 
-    def test_lora_expanding_shape_with_normal_lora(self):
-        # This test checks if it works when a lora with expanded shapes (like control loras) but
-        # another lora with correct shapes is loaded. The opposite direction isn't supported and is
-        # tested with it.
+    def test_normal_lora_with_expanded_lora_raises_error(self):
+        # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
+        # load shape expanded LoRA (such as Control LoRA).
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
 
         # Change the transformer config to mimic a real use case.
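As context for the zero-padding behavior asserted above (the "zero padded to match the state dict" log and the 2 * in_features shape checks), here is a minimal, standalone sketch of the underlying idea; the tensor names and sizes are illustrative only and are not taken from the diffusers codebase or from this commit.

import torch

# Illustrative sizes only: `in_features` and `rank` are made up for the sketch.
in_features, rank = 16, 4
lora_A = torch.randn(rank, in_features)  # LoRA A-matrix trained against the original input width

# Suppose the base layer's input width was expanded to 2 * in_features (as for Flux Control).
expanded_in_features = 2 * in_features
padded_A = torch.zeros(rank, expanded_in_features)
padded_A[:, :in_features] = lora_A  # copy the trained columns; the new columns stay zero

# On inputs whose extra channels are zero, the padded adapter reproduces the original output.
x = torch.randn(1, in_features)
x_expanded = torch.cat([x, torch.zeros(1, in_features)], dim=1)
assert torch.allclose(x @ lora_A.T, x_expanded @ padded_A.T)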

tests/lora/test_lora_layers_ltx_video.py

Lines changed: 2 additions & 45 deletions
@@ -15,8 +15,6 @@
 import sys
 import unittest
 
-import numpy as np
-import pytest
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -26,18 +24,12 @@
     LTXPipeline,
     LTXVideoTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
-    floats_tensor,
-    is_torch_version,
-    require_peft_backend,
-    skip_mps,
-    torch_device,
-)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
 
 
 sys.path.append(".")
 
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
+from utils import PeftLoraLoaderMixinTests  # noqa: E402
 
 
 @require_peft_backend
@@ -107,41 +99,6 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @skip_mps
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
-        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
-        strict=True,
-    )
-    def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-
-            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
-
-            # corrupt one LoRA weight with `inf` values
-            with torch.no_grad():
-                pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
-
-            # with `safe_fusing=True` we should see an Error
-            with self.assertRaises(ValueError):
-                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
-            # without we should not see an error, but every image will be black
-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
-
-            out = pipe(
-                "test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
-            )[0]
-
-            self.assertTrue(np.isnan(out).all())
-
     def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
         super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
 
tests/lora/utils.py

Lines changed: 110 additions & 0 deletions
@@ -1988,3 +1988,113 @@ def test_set_adapters_match_attention_kwargs(self):
             np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results as set_adapters().",
         )
+
+    @require_peft_version_greater("0.13.2")
+    def test_lora_B_bias(self):
+        # Currently, this test is only relevant for Flux Control LoRA as we are not
+        # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # keep track of the bias values of the base layers to perform checks later.
+        bias_values = {}
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+        for name, module in denoiser.named_modules():
+            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+                if module.bias is not None:
+                    bias_values[name] = module.bias.data.clone()
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.INFO)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        denoiser_lora_config.lora_bias = False
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        pipe.delete_adapters("adapter-1")
+
+        denoiser_lora_config.lora_bias = True
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
+
+    def test_correct_lora_configs_with_different_ranks(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        if self.unet_kwargs is not None:
+            pipe.unet.delete_adapters("adapter-1")
+        else:
+            pipe.transformer.delete_adapters("adapter-1")
+
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+        for name, _ in denoiser.named_modules():
+            if "to_k" in name and "attn" in name and "lora" not in name:
+                module_name_to_rank_update = name.replace(".base_layer.", ".")
+                break
+
+        # change the rank_pattern
+        updated_rank = denoiser_lora_config.r * 2
+        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
+
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
+
+        self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
+
+        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+        if self.unet_kwargs is not None:
+            pipe.unet.delete_adapters("adapter-1")
+        else:
+            pipe.transformer.delete_adapters("adapter-1")
+
+        # similarly change the alpha_pattern
+        updated_alpha = denoiser_lora_config.lora_alpha * 2
+        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(
+                pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+            )
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(
+                pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+            )
+
+        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
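A note on the self.unet_kwargs checks that run through the generalized tests above: the shared mixin has to work for both UNet-based and transformer-based pipelines, so it selects the denoiser at runtime. The following is a minimal, self-contained sketch of that dispatch idiom; DummyPipeline and get_denoiser are hypothetical stand-ins, not part of this commit or of diffusers.

# Sketch of the denoiser-dispatch pattern used by the tests in tests/lora/utils.py.
class DummyPipeline:
    def __init__(self, unet=None, transformer=None):
        self.unet = unet
        self.transformer = transformer


def get_denoiser(pipe, unet_kwargs):
    # Mirrors `pipe.unet if self.unet_kwargs is not None else pipe.transformer`
    # from the tests added above.
    return pipe.unet if unet_kwargs is not None else pipe.transformer


# Transformer-based pipelines (e.g. the Flux and LTX-Video tests) leave `unet_kwargs` as None.
assert get_denoiser(DummyPipeline(transformer="transformer-module"), unet_kwargs=None) == "transformer-module"
assert get_denoiser(DummyPipeline(unet="unet-module"), unet_kwargs={"sample_size": 32}) == "unet-module"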
