Skip to content

Commit e2d2650

Browse files
committed
update
1 parent 1450c2a commit e2d2650

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

tests/lora/test_lora_layers_lumina2.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sys
1616
import unittest
1717

18+
import numpy as np
1819
import torch
1920
from transformers import AutoTokenizer, GemmaForCausalLM
2021

@@ -24,12 +25,12 @@
2425
Lumina2Text2ImgPipeline,
2526
Lumina2Transformer2DModel,
2627
)
27-
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
28+
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device
2829

2930

3031
sys.path.append(".")
3132

32-
from utils import PeftLoraLoaderMixinTests # noqa: E402
33+
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
3334

3435

3536
@require_peft_backend
@@ -130,3 +131,45 @@ def test_simple_inference_with_text_lora_fused(self):
130131
    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_save_load(self):
        # Deliberately empty: the mixin's text-encoder LoRA save/load test does
        # not apply to Lumina2, so the inherited test is skipped wholesale.
        pass
134+
135+
def test_lora_fuse_nan(self):
136+
for scheduler_cls in self.scheduler_classes:
137+
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
138+
pipe = self.pipeline_class(**components)
139+
pipe = pipe.to(torch_device)
140+
pipe.set_progress_bar_config(disable=None)
141+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
142+
143+
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
144+
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
145+
self.assertTrue(
146+
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
147+
)
148+
149+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
150+
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
151+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
152+
153+
# corrupt one LoRA weight with `inf` values
154+
with torch.no_grad():
155+
if self.unet_kwargs:
156+
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
157+
"adapter-1"
158+
].weight += float("inf")
159+
else:
160+
named_modules = [name for name, _ in pipe.transformer.named_modules()]
161+
has_attn1 = any("attn1" in name for name in named_modules)
162+
if has_attn1:
163+
pipe.transformer.layers[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
164+
else:
165+
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
166+
167+
# with `safe_fusing=True` we should see an Error
168+
with self.assertRaises(ValueError):
169+
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
170+
171+
# without we should not see an error, but every image will be black
172+
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
173+
out = pipe(**inputs)[0]
174+
175+
self.assertTrue(np.isnan(out).all())

0 commit comments

Comments
 (0)