11import contextlib
2- import sys
32
43import pytest
54import torch
65
76from lightning .fabric import Fabric
87from lightning .fabric .utilities .spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0 , SpikeDetection , TrainingSpikeException
8+ from tests_fabric .helpers .runif import RunIf
99
1010
1111def spike_detection_test (fabric , global_rank_spike , spike_value , should_raise ):
@@ -32,6 +32,8 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
3232@pytest .mark .flaky (max_runs = 3 )
3333@pytest .mark .parametrize (
3434 ("global_rank_spike" , "num_devices" , "spike_value" , "finite_only" ),
35+ # NOTE FOR ALL FOLLOWING TESTS:
36+ # adding run on linux only because multiprocessing on other platforms takes forever
3537 [
3638 pytest .param (0 , 1 , None , True ),
3739 pytest .param (0 , 1 , None , False ),
@@ -41,150 +43,22 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
4143 pytest .param (0 , 1 , float ("-inf" ), False ),
4244 pytest .param (0 , 1 , float ("NaN" ), True ),
4345 pytest .param (0 , 1 , float ("NaN" ), False ),
44- pytest .param (
45- 0 ,
46- 2 ,
47- None ,
48- True ,
49- marks = pytest .mark .skipif (
50- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
51- ),
52- ),
53- pytest .param (
54- 0 ,
55- 2 ,
56- None ,
57- False ,
58- marks = pytest .mark .skipif (
59- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
60- ),
61- ),
62- pytest .param (
63- 1 ,
64- 2 ,
65- None ,
66- True ,
67- marks = pytest .mark .skipif (
68- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
69- ),
70- ),
71- pytest .param (
72- 1 ,
73- 2 ,
74- None ,
75- False ,
76- marks = pytest .mark .skipif (
77- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
78- ),
79- ),
80- pytest .param (
81- 0 ,
82- 2 ,
83- float ("inf" ),
84- True ,
85- marks = pytest .mark .skipif (
86- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
87- ),
88- ),
89- pytest .param (
90- 0 ,
91- 2 ,
92- float ("inf" ),
93- False ,
94- marks = pytest .mark .skipif (
95- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
96- ),
97- ),
98- pytest .param (
99- 1 ,
100- 2 ,
101- float ("inf" ),
102- True ,
103- marks = pytest .mark .skipif (
104- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
105- ),
106- ),
107- pytest .param (
108- 1 ,
109- 2 ,
110- float ("inf" ),
111- False ,
112- marks = pytest .mark .skipif (
113- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
114- ),
115- ),
116- pytest .param (
117- 0 ,
118- 2 ,
119- float ("-inf" ),
120- True ,
121- marks = pytest .mark .skipif (
122- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
123- ),
124- ),
125- pytest .param (
126- 0 ,
127- 2 ,
128- float ("-inf" ),
129- False ,
130- marks = pytest .mark .skipif (
131- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
132- ),
133- ),
134- pytest .param (
135- 1 ,
136- 2 ,
137- float ("-inf" ),
138- True ,
139- marks = pytest .mark .skipif (
140- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
141- ),
142- ),
143- pytest .param (
144- 1 ,
145- 2 ,
146- float ("-inf" ),
147- False ,
148- marks = pytest .mark .skipif (
149- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
150- ),
151- ),
152- pytest .param (
153- 0 ,
154- 2 ,
155- float ("NaN" ),
156- True ,
157- marks = pytest .mark .skipif (
158- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
159- ),
160- ),
161- pytest .param (
162- 0 ,
163- 2 ,
164- float ("NaN" ),
165- False ,
166- marks = pytest .mark .skipif (
167- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
168- ),
169- ),
170- pytest .param (
171- 1 ,
172- 2 ,
173- float ("NaN" ),
174- True ,
175- marks = pytest .mark .skipif (
176- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
177- ),
178- ),
179- pytest .param (
180- 1 ,
181- 2 ,
182- float ("NaN" ),
183- False ,
184- marks = pytest .mark .skipif (
185- sys .platform != "linux" , reason = "multiprocessing on other platforms takes forever"
186- ),
187- ),
46+ pytest .param (0 , 2 , None , True , marks = RunIf (linux_only = True )),
47+ pytest .param (0 , 2 , None , False , marks = RunIf (linux_only = True )),
48+ pytest .param (1 , 2 , None , True , marks = RunIf (linux_only = True )),
49+ pytest .param (1 , 2 , None , False , marks = RunIf (linux_only = True )),
50+ pytest .param (0 , 2 , float ("inf" ), True , marks = RunIf (linux_only = True )),
51+ pytest .param (0 , 2 , float ("inf" ), False , marks = RunIf (linux_only = True )),
52+ pytest .param (1 , 2 , float ("inf" ), True , marks = RunIf (linux_only = True )),
53+ pytest .param (1 , 2 , float ("inf" ), False , marks = RunIf (linux_only = True )),
54+ pytest .param (0 , 2 , float ("-inf" ), True , marks = RunIf (linux_only = True )),
55+ pytest .param (0 , 2 , float ("-inf" ), False , marks = RunIf (linux_only = True )),
56+ pytest .param (1 , 2 , float ("-inf" ), True , marks = RunIf (linux_only = True )),
57+ pytest .param (1 , 2 , float ("-inf" ), False , marks = RunIf (linux_only = True )),
58+ pytest .param (0 , 2 , float ("NaN" ), True , marks = RunIf (linux_only = True )),
59+ pytest .param (0 , 2 , float ("NaN" ), False , marks = RunIf (linux_only = True )),
60+ pytest .param (1 , 2 , float ("NaN" ), True , marks = RunIf (linux_only = True )),
61+ pytest .param (1 , 2 , float ("NaN" ), False , marks = RunIf (linux_only = True )),
18862 ],
18963)
19064@pytest .mark .skipif (not _TORCHMETRICS_GREATER_EQUAL_1_0_0 , reason = "requires torchmetrics>=1.0.0" )
0 commit comments