
Commit 8c80718

fixture
1 parent 78b0f52 commit 8c80718

File tree

1 file changed: +6 -2 lines changed


tests/models/test_attention_processor.py

Lines changed: 6 additions & 2 deletions
@@ -81,9 +81,13 @@ def test_only_cross_attention(self):
 
 
 class DeprecatedAttentionBlockTests(unittest.TestCase):
+    @pytest.fixture(scope="session")
+    def is_dist_enabled(pytestconfig):
+        return pytestconfig.getoption("dist") == "loadfile"
+
     @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails on our GPU CI because of `disfile`.",
+        condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
+        reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
         strict=True,
     )
     def test_conversion_when_using_device_map(self):
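
For context (not part of the commit): the added fixture reads the `--dist` option that the pytest-xdist plugin registers, so the `xfail` condition is only meant to fire when the suite is distributed with `--dist loadfile`. A minimal, self-contained sketch of the same pattern, using a hypothetical fixture name `dist_mode` and a hypothetical test, and falling back to "no" so it also runs when pytest-xdist is not installed:

# conftest.py -- sketch; assumes pytest is installed, pytest-xdist is optional
import pytest

@pytest.fixture(scope="session")
def dist_mode(pytestconfig):
    # pytest-xdist registers --dist; default to "no" if the plugin is absent.
    return pytestconfig.getoption("dist", default="no")

# test_example.py -- hypothetical test consuming the fixture
def test_skips_under_loadfile(dist_mode):
    # Fixture values are resolved at run time, so the check lives in the test
    # body rather than in a collection-time decorator.
    if dist_mode == "loadfile":
        pytest.skip("skipped under `pytest -n auto --dist loadfile`")
    assert True

Fixtures cannot be evaluated inside a mark decorator's arguments at collection time, which is why this sketch checks the mode inside the test body instead.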
