Skip to content

Commit 27ad9e9

Browse files
authored
xfail collective tests (#18779)
1 parent c39f680 commit 27ad9e9

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

tests/tests_fabric/plugins/collectives/test_torch_collective.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from lightning.fabric.plugins.environments import LightningEnvironment
1212
from lightning.fabric.strategies.ddp import DDPStrategy
1313
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
14-
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13
14+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
1515

1616
from tests_fabric.helpers.runif import RunIf
1717

@@ -231,9 +231,7 @@ def _test_distributed_collectives_fn(strategy, collective):
231231

232232

233233
@skip_distributed_unavailable
234-
@pytest.mark.parametrize("n", [1, 2])
235-
@RunIf(skip_windows=True)
236-
@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13
234+
@pytest.mark.parametrize("n", [1, pytest.param(2, marks=pytest.mark.xfail(raises=TimeoutError, strict=False))])
237235
def test_collectives_distributed(n):
238236
collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)
239237

@@ -268,8 +266,8 @@ def _test_two_groups(strategy, left_collective, right_collective):
268266

269267

270268
@skip_distributed_unavailable
271-
@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception") # Todo
272-
@pytest.mark.xfail(strict=False, reason="TODO(carmocca): causing hangs in CI")
269+
@RunIf(skip_windows=True) # unhandled timeouts
270+
@pytest.mark.xfail(raises=TimeoutError, strict=False)
273271
def test_two_groups():
274272
collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
275273

@@ -285,8 +283,7 @@ def _test_default_process_group(strategy, *collectives):
285283

286284

287285
@skip_distributed_unavailable
288-
@RunIf(skip_windows=True)
289-
@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13
286+
@RunIf(skip_windows=True) # unhandled timeouts
290287
def test_default_process_group():
291288
collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
292289

0 commit comments

Comments
 (0)