Skip to content

Commit 105210f

Browse files
committed
RunIf(linux_only=True)
1 parent 636be26 commit 105210f

File tree

4 files changed

+46
-292
lines changed

4 files changed

+46
-292
lines changed

src/lightning/fabric/utilities/testing/_runif.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _runif_reasons(
4040
standalone: bool = False,
4141
deepspeed: bool = False,
4242
dynamo: bool = False,
43+
linux_only: bool = False,
4344
) -> tuple[list[str], dict[str, bool]]:
4445
"""Construct reasons for pytest skipif.
4546
@@ -123,4 +124,7 @@ def _runif_reasons(
123124
if not is_dynamo_supported():
124125
reasons.append("torch.dynamo")
125126

127+
if linux_only and sys.platform != "linux":
128+
reasons.append("only linux")
129+
126130
return reasons, kwargs

src/lightning/pytorch/utilities/testing/_runif.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from lightning_utilities.core.imports import RequirementCache
1717

18-
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
18+
from lightning.fabric.utilities.testing import _runif_reasons as _fabric_run_if
1919
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
2020
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
2121
from lightning.pytorch.core.module import _ONNX_AVAILABLE
@@ -42,6 +42,7 @@ def _runif_reasons(
4242
psutil: bool = False,
4343
sklearn: bool = False,
4444
onnx: bool = False,
45+
linux_only: bool = False,
4546
) -> tuple[list[str], dict[str, bool]]:
4647
"""Construct reasons for pytest skipif.
4748
@@ -67,7 +68,7 @@ def _runif_reasons(
6768
6869
"""
6970

70-
reasons, kwargs = fabric_run_if(
71+
reasons, kwargs = _fabric_run_if(
7172
min_cuda_gpus=min_cuda_gpus,
7273
min_torch=min_torch,
7374
max_torch=max_torch,
@@ -79,6 +80,7 @@ def _runif_reasons(
7980
standalone=standalone,
8081
deepspeed=deepspeed,
8182
dynamo=dynamo,
83+
linux_only=linux_only,
8284
)
8385

8486
if rich and not _RICH_AVAILABLE:

tests/tests_fabric/utilities/test_spike.py

Lines changed: 19 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import contextlib
2-
import sys
32

43
import pytest
54
import torch
65

76
from lightning.fabric import Fabric
87
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
8+
from tests_fabric.helpers.runif import RunIf
99

1010

1111
def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
@@ -32,6 +32,8 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
3232
@pytest.mark.flaky(max_runs=3)
3333
@pytest.mark.parametrize(
3434
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
35+
# NOTE FOR ALL FOLLOWING TESTS:
36+
# adding run on linux only because multiprocessing on other platforms takes forever
3537
[
3638
pytest.param(0, 1, None, True),
3739
pytest.param(0, 1, None, False),
@@ -41,150 +43,22 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
4143
pytest.param(0, 1, float("-inf"), False),
4244
pytest.param(0, 1, float("NaN"), True),
4345
pytest.param(0, 1, float("NaN"), False),
44-
pytest.param(
45-
0,
46-
2,
47-
None,
48-
True,
49-
marks=pytest.mark.skipif(
50-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
51-
),
52-
),
53-
pytest.param(
54-
0,
55-
2,
56-
None,
57-
False,
58-
marks=pytest.mark.skipif(
59-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
60-
),
61-
),
62-
pytest.param(
63-
1,
64-
2,
65-
None,
66-
True,
67-
marks=pytest.mark.skipif(
68-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
69-
),
70-
),
71-
pytest.param(
72-
1,
73-
2,
74-
None,
75-
False,
76-
marks=pytest.mark.skipif(
77-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
78-
),
79-
),
80-
pytest.param(
81-
0,
82-
2,
83-
float("inf"),
84-
True,
85-
marks=pytest.mark.skipif(
86-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
87-
),
88-
),
89-
pytest.param(
90-
0,
91-
2,
92-
float("inf"),
93-
False,
94-
marks=pytest.mark.skipif(
95-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
96-
),
97-
),
98-
pytest.param(
99-
1,
100-
2,
101-
float("inf"),
102-
True,
103-
marks=pytest.mark.skipif(
104-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
105-
),
106-
),
107-
pytest.param(
108-
1,
109-
2,
110-
float("inf"),
111-
False,
112-
marks=pytest.mark.skipif(
113-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
114-
),
115-
),
116-
pytest.param(
117-
0,
118-
2,
119-
float("-inf"),
120-
True,
121-
marks=pytest.mark.skipif(
122-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
123-
),
124-
),
125-
pytest.param(
126-
0,
127-
2,
128-
float("-inf"),
129-
False,
130-
marks=pytest.mark.skipif(
131-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
132-
),
133-
),
134-
pytest.param(
135-
1,
136-
2,
137-
float("-inf"),
138-
True,
139-
marks=pytest.mark.skipif(
140-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
141-
),
142-
),
143-
pytest.param(
144-
1,
145-
2,
146-
float("-inf"),
147-
False,
148-
marks=pytest.mark.skipif(
149-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
150-
),
151-
),
152-
pytest.param(
153-
0,
154-
2,
155-
float("NaN"),
156-
True,
157-
marks=pytest.mark.skipif(
158-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
159-
),
160-
),
161-
pytest.param(
162-
0,
163-
2,
164-
float("NaN"),
165-
False,
166-
marks=pytest.mark.skipif(
167-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
168-
),
169-
),
170-
pytest.param(
171-
1,
172-
2,
173-
float("NaN"),
174-
True,
175-
marks=pytest.mark.skipif(
176-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
177-
),
178-
),
179-
pytest.param(
180-
1,
181-
2,
182-
float("NaN"),
183-
False,
184-
marks=pytest.mark.skipif(
185-
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
186-
),
187-
),
46+
pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)),
47+
pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)),
48+
pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)),
49+
pytest.param(1, 2, None, False, marks=RunIf(linux_only=True)),
50+
pytest.param(0, 2, float("inf"), True, marks=RunIf(linux_only=True)),
51+
pytest.param(0, 2, float("inf"), False, marks=RunIf(linux_only=True)),
52+
pytest.param(1, 2, float("inf"), True, marks=RunIf(linux_only=True)),
53+
pytest.param(1, 2, float("inf"), False, marks=RunIf(linux_only=True)),
54+
pytest.param(0, 2, float("-inf"), True, marks=RunIf(linux_only=True)),
55+
pytest.param(0, 2, float("-inf"), False, marks=RunIf(linux_only=True)),
56+
pytest.param(1, 2, float("-inf"), True, marks=RunIf(linux_only=True)),
57+
pytest.param(1, 2, float("-inf"), False, marks=RunIf(linux_only=True)),
58+
pytest.param(0, 2, float("NaN"), True, marks=RunIf(linux_only=True)),
59+
pytest.param(0, 2, float("NaN"), False, marks=RunIf(linux_only=True)),
60+
pytest.param(1, 2, float("NaN"), True, marks=RunIf(linux_only=True)),
61+
pytest.param(1, 2, float("NaN"), False, marks=RunIf(linux_only=True)),
18862
],
18963
)
19064
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")

0 commit comments

Comments
 (0)