Commit 514848f

_xfail_gloo_windows

1 parent e5f2120

File tree

9 files changed: +27 -20 lines

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
-_TORCH_GREATER_EQUAL_2_8 = RequirementCache("torch>=2.8.0")
+_TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")

 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
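
The rename reflects the narrowed pin: the flag is now truthy only for torch 2.8.x, not for every release at or above 2.8.0. A minimal sketch of how such a flag behaves, assuming the `lightning_utilities` `RequirementCache` API this module already builds on:

    from lightning_utilities.core.imports import RequirementCache

    _TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")

    # The cache checks the installed distribution against the requirement and
    # is truthy only when the pin is satisfied: torch 2.8.1 matches, 2.9.0 does not.
    if _TORCH_EQUAL_2_8:
        print("running under torch 2.8.x")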

tests/tests_pytorch/callbacks/test_spike.py

Lines changed: 9 additions & 9 deletions
@@ -7,7 +7,7 @@
 from lightning.fabric.utilities.spike import TrainingSpikeException
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.spike import SpikeDetection
-from tests_pytorch.helpers.runif import _XFAIL_GLOO_WINDOWS, RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 class IdentityModule(LightningModule):
@@ -54,14 +54,14 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
     # NOTE FOR ALL FOLLOWING TESTS:
     # adding run on linux only because multiprocessing on other platforms takes forever
     [
-        pytest.param(0, 1, None, True),
-        pytest.param(0, 1, None, False),
-        pytest.param(0, 1, float("inf"), True, marks=_XFAIL_GLOO_WINDOWS),
-        pytest.param(0, 1, float("inf"), False, marks=_XFAIL_GLOO_WINDOWS),
-        pytest.param(0, 1, float("-inf"), True, marks=_XFAIL_GLOO_WINDOWS),
-        pytest.param(0, 1, float("-inf"), False, marks=_XFAIL_GLOO_WINDOWS),
-        pytest.param(0, 1, float("NaN"), True, marks=_XFAIL_GLOO_WINDOWS),
-        pytest.param(0, 1, float("NaN"), False, marks=_XFAIL_GLOO_WINDOWS),
+        pytest.param(0, 1, None, True, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, None, False, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, float("inf"), True, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, float("inf"), False, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, float("-inf"), True, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, float("-inf"), False, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, float("NaN"), True, marks=_xfail_gloo_windows),
+        pytest.param(0, 1, float("NaN"), False, marks=_xfail_gloo_windows),
         pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)),
         pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)),
         pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)),
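
Besides lowercasing the name, the change extends the mark to the single-device `None` cases, so every single-device case in this table now carries the gloo xfail. As a usage note, `pytest.param`'s `marks` argument accepts either a single mark or a list, so the xfail could also be combined with other conditions; a hypothetical combination, not part of this commit:

    pytest.param(0, 2, None, True, marks=[_xfail_gloo_windows, RunIf(linux_only=True)]),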

tests/tests_pytorch/helpers/runif.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
 import pytest

 from lightning.fabric.utilities.imports import _IS_WINDOWS
-from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_8
+from lightning.pytorch.utilities.imports import _TORCH_EQUAL_2_8
 from lightning.pytorch.utilities.testing import _runif_reasons


@@ -24,9 +24,9 @@ def RunIf(**kwargs):


 # todo: RuntimeError: makeDeviceForHostname(): unsupported gloo device
-_XFAIL_GLOO_WINDOWS = pytest.mark.xfail(
+_xfail_gloo_windows = pytest.mark.xfail(
     RuntimeError,
     strict=True,
-    condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_8),
+    condition=(_IS_WINDOWS and _TORCH_EQUAL_2_8),
     reason="makeDeviceForHostname(): unsupported gloo device",
 )
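
With `condition=` supplied as a keyword, the mark only takes effect on Windows under torch 2.8.x and is inert everywhere else. Because `strict=True`, an affected run that unexpectedly passes is reported as a failure, so the marker also signals when the upstream gloo bug gets fixed. A minimal sketch of these semantics, using standard pytest behavior and a stand-in condition:

    import pytest

    ON_AFFECTED_PLATFORM = False  # stand-in for (_IS_WINDOWS and _TORCH_EQUAL_2_8)

    @pytest.mark.xfail(condition=ON_AFFECTED_PLATFORM, strict=True, reason="demo")
    def test_demo():
        # When the condition is False, the mark is inert and the test must pass.
        # When it is True, a failure is reported as the expected XFAIL, while a
        # pass is reported as XPASS(strict) and fails the run.
        assert True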

tests/tests_pytorch/loops/test_prediction_loop.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
+from tests_pytorch.helpers.runif import _xfail_gloo_windows


 def test_prediction_loop_stores_predictions(tmp_path):
@@ -51,6 +52,7 @@ def predict_step(self, batch, batch_idx):
     assert trainer.predict_loop.predictions == []


+@_xfail_gloo_windows
 @pytest.mark.parametrize("use_distributed_sampler", [False, True])
 def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
     """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""

tests/tests_pytorch/models/test_amp.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 class AMPTestModel(BoringModel):
@@ -53,7 +53,7 @@ def _assert_autocast_enabled(self):
     [
         ("single_device", "16-mixed", 1),
         ("single_device", "bf16-mixed", 1),
-        ("ddp_spawn", "16-mixed", 2),
+        pytest.param("ddp_spawn", "16-mixed", 2, marks=_xfail_gloo_windows),
         pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)),
     ],
 )

tests/tests_pytorch/serve/test_servable_module_validator.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator
+from tests_pytorch.helpers.runif import _xfail_gloo_windows


 class ServableBoringModel(BoringModel, ServableModule):
@@ -28,13 +29,14 @@ def configure_response(self):
         return {"output": [0, 1]}


-@pytest.mark.xfail(strict=False, reason="test is too flaky in CI")  # todo
+@pytest.mark.flaky(reruns=3)
 def test_servable_module_validator():
     model = ServableBoringModel()
     callback = ServableModuleValidator()
     callback.on_train_start(Trainer(accelerator="cpu"), model)


+@_xfail_gloo_windows
 @pytest.mark.flaky(reruns=3)
 def test_servable_module_validator_with_trainer(tmp_path, mps_count_0):
     callback = ServableModuleValidator()
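
Swapping the non-strict xfail for `flaky(reruns=3)` changes the failure mode: a non-strict xfail can never fail CI even if the test breaks for good, whereas the rerun marker tolerates transient flakes but still fails after the retries are exhausted. A minimal sketch of the rerun behavior, assuming the pytest-rerunfailures plugin that provides the `flaky` marker used here:

    import pytest

    _attempts = {"count": 0}

    @pytest.mark.flaky(reruns=3)  # retry up to 3 times before reporting a failure
    def test_sometimes_flaky():
        _attempts["count"] += 1
        # Fails on the first attempt, passes on the rerun: the test ultimately
        # reports as passed, with the rerun noted by the plugin.
        assert _attempts["count"] > 1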

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,7 @@
 from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
 from lightning.pytorch.trainer.states import TrainerFn
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 @mock.patch("lightning.pytorch.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -194,6 +194,7 @@ def on_fit_start(self) -> None:
         assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)


+@_xfail_gloo_windows
 def test_memory_sharing_disabled(tmp_path):
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
     conditions on model updates."""
@@ -219,6 +220,7 @@ def test_check_for_missing_main_guard():
     launcher.launch(function=Mock())


+@_xfail_gloo_windows
 def test_fit_twice_raises(mps_count_0):
     model = BoringModel()
     trainer = Trainer(

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,7 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _update_dataloader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 @RunIf(skip_windows=True)
@@ -123,6 +123,7 @@ def on_train_end(self):
         self.ctx.__exit__(None, None, None)


+@_xfail_gloo_windows
 @pytest.mark.parametrize("num_workers", [0, 1, 2])
 def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
     """Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 def test__training_step__log(tmp_path):
@@ -346,7 +346,7 @@ def validation_step(self, batch, batch_idx):
     ("devices", "accelerator"),
     [
         (1, "cpu"),
-        (2, "cpu"),
+        pytest.param(2, "cpu", marks=_xfail_gloo_windows),
         pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
     ],
 )
