Commit 84bc0e4

[LoRA] fix cross_attention_kwargs problems and tighten tests (#7388)
* debugging
* let's see the numbers
* restrict tolerance
* increase inference steps
* shallow copy of `cross_attention_kwargs`
* remove print
1 parent 5584e1c · commit 84bc0e4

File tree

2 files changed: +8 -2 lines changed

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1178,6 +1178,7 @@ def forward(
         # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
         # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
         if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
             lora_scale = cross_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
```
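The added `copy()` is the heart of the fix: `dict.pop` mutates the dictionary the caller passed in, so without the copy the first UNet forward pass strips `"scale"` out of the user's `cross_attention_kwargs`, and every subsequent call silently falls back to the default scale of 1.0. A minimal standalone sketch of the failure mode (not the diffusers code itself):

```python
def pop_scale_in_place(cross_attention_kwargs):
    # Mutates the caller's dict: "scale" is gone after the first call.
    return cross_attention_kwargs.pop("scale", 1.0)

def pop_scale_from_copy(cross_attention_kwargs):
    # Shallow copy, as in the fix: the caller's dict stays intact.
    cross_attention_kwargs = cross_attention_kwargs.copy()
    return cross_attention_kwargs.pop("scale", 1.0)

kwargs = {"scale": 0.5}
print(pop_scale_in_place(kwargs), pop_scale_in_place(kwargs))    # 0.5 1.0 -- scale lost
kwargs = {"scale": 0.5}
print(pop_scale_from_copy(kwargs), pop_scale_from_copy(kwargs))  # 0.5 0.5
```

A shallow copy suffices here because only a top-level key is removed; the values themselves are never modified.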

tests/lora/test_lora_layers_peft.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -158,7 +158,7 @@ def get_dummy_inputs(self, with_generator=True):
 
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "num_inference_steps": 2,
+            "num_inference_steps": 5,
             "guidance_scale": 6.0,
             "output_type": "np",
         }
```
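With only 2 denoising steps, the outputs with and without a LoRA scale can land within tolerance of each other; 5 steps give the scale a measurable effect on the result. For orientation, a hedged sketch of how such an input dict is consumed by a pipeline (the checkpoint id is illustrative, and loading it downloads a real model):

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint id -- any Stable Diffusion 1.x checkpoint would do.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipeline_inputs = {
    "prompt": "A painting of a squirrel eating a burger",
    "num_inference_steps": 5,
    "guidance_scale": 6.0,
    "output_type": "np",  # return numpy arrays instead of PIL images
}
images = pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
```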
```diff
@@ -589,7 +589,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self):
                 **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
             ).images
             self.assertTrue(
-                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
                 "Lora + scale should change the output",
             )
 
```
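For reference, `np.allclose(a, b, rtol=..., atol=...)` passes when `|a - b| <= atol + rtol * |b|` holds elementwise, so shrinking both bounds from 1e-3 to 1e-4 means even small scale-induced differences count as "the output changed". A tiny illustration with made-up numbers:

```python
import numpy as np

a = np.array([0.5000, 0.5000])  # hypothetical output without scale
b = np.array([0.5002, 0.5005])  # hypothetical output with scale=0.5

# allclose passes iff |a - b| <= atol + rtol * |b| elementwise.
print(np.allclose(a, b, atol=1e-3, rtol=1e-3))  # True  -> `not allclose` assertion fails
print(np.allclose(a, b, atol=1e-4, rtol=1e-4))  # False -> `not allclose` assertion passes
```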

```diff
@@ -1300,6 +1300,11 @@ def test_integration_logits_with_scale(self):
         pipe.load_lora_weights(lora_id)
         pipe = pipe.to("cuda")
 
+        self.assertTrue(
+            self.check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in UNet",
+        )
+
         self.assertTrue(
             self.check_if_lora_correctly_set(pipe.text_encoder),
             "Lora not correctly set in text encoder 2",
```
