
Commit dbc8427

Fix test_lora_fuse_nan.
1 parent a5b78d1 commit dbc8427

File tree

1 file changed: +20 -12 lines changed

tests/lora/utils.py

Lines changed: 20 additions & 12 deletions
@@ -1563,7 +1563,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
                "output with no lora and output with lora disabled should give same results",
            )

-    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
@@ -1595,17 +1594,26 @@ def test_lora_fuse_nan(self):
                    ].weight += float("inf")
                else:
                    named_modules = [name for name, _ in pipe.transformer.named_modules()]
-                    tower_name = (
-                        "transformer_blocks"
-                        if any(name == "transformer_blocks" for name in named_modules)
-                        else "blocks"
-                    )
-                    transformer_tower = getattr(pipe.transformer, tower_name)
-                    has_attn1 = any("attn1" in name for name in named_modules)
-                    if has_attn1:
-                        transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
-                    else:
-                        transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+                    possible_tower_names = [
+                        "transformer_blocks",
+                        "blocks",
+                        "joint_transformer_blocks",
+                        "single_transformer_blocks",
+                    ]
+                    filtered_tower_names = [
+                        tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+                    ]
+                    if len(filtered_tower_names) == 0:
+                        pytest.xfail(
+                            reason=f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
+                        )
+                    for tower_name in filtered_tower_names:
+                        transformer_tower = getattr(pipe.transformer, tower_name)
+                        has_attn1 = any("attn1" in name for name in named_modules)
+                        if has_attn1:
+                            transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
+                        else:
+                            transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
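For context on the pattern the new test code uses: rather than hard-coding a single attribute name, it probes `pipe.transformer` for whichever block container exists and corrupts a LoRA weight in each one it finds. Below is a minimal, standalone sketch of that `hasattr`-based discovery step, using a hypothetical `ToyTransformer` class (the class and the `find_towers` helper are illustrative only, not part of this commit):

import torch.nn as nn

# Candidate container names mirrored from the test above.
POSSIBLE_TOWER_NAMES = [
    "transformer_blocks",
    "blocks",
    "joint_transformer_blocks",
    "single_transformer_blocks",
]

class ToyTransformer(nn.Module):
    # Hypothetical stand-in for `pipe.transformer`; only the attribute name matters here.
    def __init__(self):
        super().__init__()
        self.single_transformer_blocks = nn.ModuleList([nn.Linear(4, 4)])

def find_towers(transformer: nn.Module) -> list[str]:
    # Keep only the block containers this particular transformer actually exposes.
    return [name for name in POSSIBLE_TOWER_NAMES if hasattr(transformer, name)]

print(find_towers(ToyTransformer()))  # -> ['single_transformer_blocks']

Iterating over every matching container (instead of assuming `transformer_blocks` or `blocks`) is what lets the test inject `inf` into denoisers whose attention blocks live under `joint_transformer_blocks` or `single_transformer_blocks`, while the `pytest.xfail` fallback keeps pipelines exposing none of these attributes from failing spuriously.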
