@@ -1563,7 +1563,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
15631563 "output with no lora and output with lora disabled should give same results" ,
15641564 )
15651565
1566- @skip_mps
15671566 @pytest .mark .xfail (
15681567 condition = torch .device (torch_device ).type == "cpu" and is_torch_version (">=" , "2.5" ),
15691568 reason = "Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1." ,
@@ -1595,17 +1594,26 @@ def test_lora_fuse_nan(self):
15951594 ].weight += float ("inf" )
15961595 else :
15971596 named_modules = [name for name , _ in pipe .transformer .named_modules ()]
1598- tower_name = (
1599- "transformer_blocks"
1600- if any (name == "transformer_blocks" for name in named_modules )
1601- else "blocks"
1602- )
1603- transformer_tower = getattr (pipe .transformer , tower_name )
1604- has_attn1 = any ("attn1" in name for name in named_modules )
1605- if has_attn1 :
1606- transformer_tower [0 ].attn1 .to_q .lora_A ["adapter-1" ].weight += float ("inf" )
1607- else :
1608- transformer_tower [0 ].attn .to_q .lora_A ["adapter-1" ].weight += float ("inf" )
1597+ possible_tower_names = [
1598+ "transformer_blocks" ,
1599+ "blocks" ,
1600+ "joint_transformer_blocks" ,
1601+ "single_transformer_blocks" ,
1602+ ]
1603+ filtered_tower_names = [
1604+ tower_name for tower_name in possible_tower_names if hasattr (pipe .transformer , tower_name )
1605+ ]
1606+ if len (filtered_tower_names ) == 0 :
1607+ pytest .xfail (
1608+ reason = f"`pipe.transformer` didn't have any of the following attributes: { possible_tower_names } ."
1609+ )
1610+ for tower_name in filtered_tower_names :
1611+ transformer_tower = getattr (pipe .transformer , tower_name )
1612+ has_attn1 = any ("attn1" in name for name in named_modules )
1613+ if has_attn1 :
1614+ transformer_tower [0 ].attn1 .to_q .lora_A ["adapter-1" ].weight += float ("inf" )
1615+ else :
1616+ transformer_tower [0 ].attn .to_q .lora_A ["adapter-1" ].weight += float ("inf" )
16091617
16101618 # with `safe_fusing=True` we should see an Error
16111619 with self .assertRaises (ValueError ):
0 commit comments