Skip to content

Commit 9fd5c9a

Browse files
committed
update
1 parent 395f9d9 commit 9fd5c9a

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

tests/lora/test_lora_layers_sdxl.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import torch
2424
from packaging import version
25+
from parameterized import parameterized
2526
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
2627

2728
from diffusers import (
@@ -117,17 +118,28 @@ def tearDown(self):
117118
def test_multiple_wrong_adapter_name_raises_error(self):
118119
super().test_multiple_wrong_adapter_name_raises_error()
119120

120-
def test_simple_inference_with_text_denoiser_lora_unfused(self):
121+
@parameterized.expand(
122+
[
123+
# Test actions on text_encoder LoRA only
124+
("fused", "text_encoder_only"),
125+
("unloaded", "text_encoder_only"),
126+
("save_load", "text_encoder_only"),
127+
# Test actions on both text_encoder and denoiser LoRA
128+
("fused", "text_and_denoiser"),
129+
("unloaded", "text_and_denoiser"),
130+
("unfused", "text_and_denoiser"),
131+
("save_load", "text_and_denoiser"),
132+
]
133+
)
134+
def test_lora_actions(self, action, components_to_add):
121135
if torch.cuda.is_available():
122136
expected_atol = 9e-2
123137
expected_rtol = 9e-2
124138
else:
125139
expected_atol = 1e-3
126140
expected_rtol = 1e-3
127141

128-
super().test_simple_inference_with_text_denoiser_lora_unfused(
129-
expected_atol=expected_atol, expected_rtol=expected_rtol
130-
)
142+
super().test_lora_actions(expected_atol=expected_atol, expected_rtol=expected_rtol)
131143

132144
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
133145
if torch.cuda.is_available():

tests/lora/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_simple_inference_with_partial_text_lora(self):
292292
"Removing adapters should change the output",
293293
)
294294

295-
def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3):
295+
def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3, expected_rtol=1e-3):
296296
"""
297297
A unified test for various LoRA actions (fusing, unloading, saving/loading, etc.)
298298
on different combinations of model components.
@@ -321,7 +321,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
321321
)
322322

323323
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
324-
self.assertTrue(not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=1e-3))
324+
self.assertTrue(not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=expected_rtol))
325325

326326
# 3. Perform the specified action and assert the outcome
327327
if action == "fused":
@@ -330,7 +330,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
330330
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
331331
output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0]
332332
self.assertTrue(
333-
not np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=1e-3),
333+
not np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=expected_rtol),
334334
"Fused LoRA should produce a different output from the base model.",
335335
)
336336

@@ -342,7 +342,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
342342
)
343343
output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0]
344344
self.assertTrue(
345-
np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=1e-3),
345+
np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=expected_rtol),
346346
"Output after unloading LoRA should match the original output.",
347347
)
348348

@@ -358,7 +358,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
358358
output_unfused = pipe(**inputs, generator=torch.manual_seed(0))[0]
359359

360360
self.assertTrue(
361-
np.allclose(output_fused, output_unfused, atol=expected_atol, rtol=1e-3),
361+
np.allclose(output_fused, output_unfused, atol=expected_atol, rtol=expected_rtol),
362362
"Output after unfusing should match the fused output.",
363363
)
364364

@@ -382,7 +382,7 @@ def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3)
382382

383383
output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0]
384384
self.assertTrue(
385-
np.allclose(output_lora, output_after_action, atol=expected_atol, rtol=1e-3),
385+
np.allclose(output_lora, output_after_action, atol=expected_atol, rtol=expected_rtol),
386386
"Loading from a saved checkpoint should yield the same result.",
387387
)
388388

0 commit comments

Comments (0)