Skip to content

Commit 98d20a4

Browse files
committed
_XFAIL_GLOO_WINDOWS
1 parent 105210f commit 98d20a4

File tree

6 files changed

+19
-14
lines changed

6 files changed

+19
-14
lines changed

src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@
3535
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
3636
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
3737
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
38-
38+
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
3939
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

src/lightning/fabric/utilities/spike.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
import json
2-
import operator
32
import os
43
import warnings
54
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
65

76
import torch
8-
from lightning_utilities.core.imports import compare_version
97

8+
from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
109
from lightning.fabric.utilities.types import _PATH
1110

1211
if TYPE_CHECKING:
1312
from lightning.fabric.fabric import Fabric
1413

15-
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
16-
1714

1815
class SpikeDetection:
1916
"""Spike Detection Callback.

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
from lightning.fabric.utilities import move_data_to_device
2626
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
2727
from lightning.fabric.utilities.distributed import _distributed_is_initialized
28+
from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
2829
from lightning.pytorch.utilities.data import extract_batch_size
2930
from lightning.pytorch.utilities.exceptions import MisconfigurationException
30-
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
3131
from lightning.pytorch.utilities.memory import recursive_detach
3232
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
3333
from lightning.pytorch.utilities.warnings import PossibleUserWarning

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
2626
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
2727
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
28+
_TORCH_GREATER_EQUAL_2_8 = RequirementCache("torch>=2.8.0")
2829

2930
_OMEGACONF_AVAILABLE = package_available("omegaconf")
3031
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")

tests/tests_fabric/utilities/test_spike.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch
55

66
from lightning.fabric import Fabric
7-
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
7+
from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
8+
from lightning.fabric.utilities.spike import SpikeDetection, TrainingSpikeException
89
from tests_fabric.helpers.runif import RunIf
910

1011

tests/tests_pytorch/callbacks/test_spike.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import pytest
44
import torch
55

6-
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException
6+
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCHMETRICS_GREATER_EQUAL_1_0_0
7+
from lightning.fabric.utilities.spike import TrainingSpikeException
78
from lightning.pytorch import LightningModule, Trainer
89
from lightning.pytorch.callbacks.spike import SpikeDetection
10+
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_8
911
from tests_pytorch.helpers.runif import RunIf
1012

1113

@@ -47,6 +49,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
4749
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
4850

4951

52+
# todo: RuntimeError: makeDeviceForHostname(): unsupported gloo device
53+
_XFAIL_GLOO_WINDOWS = pytest.mark.xfail(raises=RuntimeError, strict=True, condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_8))
54+
55+
5056
@pytest.mark.flaky(max_runs=3)
5157
@pytest.mark.parametrize(
5258
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
@@ -55,12 +61,12 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
5561
[
5662
pytest.param(0, 1, None, True),
5763
pytest.param(0, 1, None, False),
58-
pytest.param(0, 1, float("inf"), True),
59-
pytest.param(0, 1, float("inf"), False),
60-
pytest.param(0, 1, float("-inf"), True),
61-
pytest.param(0, 1, float("-inf"), False),
62-
pytest.param(0, 1, float("NaN"), True),
63-
pytest.param(0, 1, float("NaN"), False),
64+
pytest.param(0, 1, float("inf"), True, marks=_XFAIL_GLOO_WINDOWS),
65+
pytest.param(0, 1, float("inf"), False, marks=_XFAIL_GLOO_WINDOWS),
66+
pytest.param(0, 1, float("-inf"), True, marks=_XFAIL_GLOO_WINDOWS),
67+
pytest.param(0, 1, float("-inf"), False, marks=_XFAIL_GLOO_WINDOWS),
68+
pytest.param(0, 1, float("NaN"), True, marks=_XFAIL_GLOO_WINDOWS),
69+
pytest.param(0, 1, float("NaN"), False, marks=_XFAIL_GLOO_WINDOWS),
6470
pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)),
6571
pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)),
6672
pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)),

0 commit comments

Comments (0)