Skip to content

Commit 7595de2

Browse files
committed
Fix warning assertions for pytest 8 compatibility
1 parent 8c1b362 commit 7595de2

File tree

4 files changed

+49
-19
lines changed

4 files changed

+49
-19
lines changed

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
1+
import warnings
2+
from contextlib import contextmanager
13
from typing import Any
24

5+
import pytest
6+
37
from tpcp import BaseTpcpObject
48
from tpcp._base import _BaseTpcpObject
59

610

711
def _get_params_without_nested_class(instance: BaseTpcpObject) -> dict[str, Any]:
812
return {k: v for k, v in instance.get_params().items() if not isinstance(v, _BaseTpcpObject)}
13+
14+
15+
@contextmanager
16+
def warns_or_none(expected_warning):
17+
if expected_warning is None:
18+
with warnings.catch_warnings(record=True) as caught:
19+
warnings.simplefilter("always")
20+
yield caught
21+
else:
22+
with pytest.warns(expected_warning) as caught:
23+
yield caught

tests/test_caching.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def test_caching_twice_same_instance(self, example_class):
106106
getattr(example, action_name)(2)
107107
assert example.result_1_ == 5 * multiplier
108108

109-
with pytest.warns(None) as w:
109+
with warnings.catch_warnings(record=True) as w:
110+
warnings.simplefilter("always")
110111
getattr(example, action_name)(3)
111112
assert example.result_1_ == 6 * multiplier
112113
assert not w
@@ -126,7 +127,8 @@ def test_caching_twice_new_instance(self, example_class):
126127
assert example.result_1_ == 5 * multiplier
127128

128129
example = example_class(1, 2)
129-
with pytest.warns(None) as w:
130+
with warnings.catch_warnings(record=True) as w:
131+
warnings.simplefilter("always")
130132
getattr(example, action_name)(3)
131133
assert example.result_1_ == 6 * multiplier
132134
assert not w
@@ -165,7 +167,8 @@ def test_cache_only(self, example_class):
165167
example = example.clone()
166168

167169
# Now in the cached version
168-
with pytest.warns(None) as w:
170+
with warnings.catch_warnings(record=True) as w:
171+
warnings.simplefilter("always")
169172
getattr(example, action_name)(2)
170173
assert not w
171174
assert example.result_1_ == 5 * multiplier
@@ -197,7 +200,8 @@ def worker_func(pipe):
197200
if self.cache_method_name == "disk" and restore_in_parallel_process is True:
198201
# Disk cache can work across processes. This means, already on the first call in the new process,
199202
# we should get the cached result.
200-
with pytest.warns(None) as w:
203+
with warnings.catch_warnings(record=True) as w:
204+
warnings.simplefilter("always")
201205
pipe.action(1)
202206
assert not w
203207
else:
@@ -207,7 +211,8 @@ def worker_func(pipe):
207211

208212
if restore_in_parallel_process is True:
209213
# Id we set the restore option to True, the second call should be correctly cached
210-
with pytest.warns(None) as w:
214+
with warnings.catch_warnings(record=True) as w:
215+
warnings.simplefilter("always")
211216
pipe.action(1)
212217
assert not w
213218
else:
@@ -255,7 +260,8 @@ def test_joblib_only(self, joblib_cache, hybrid_cache_clear):
255260

256261
assert r == 3
257262

258-
with pytest.warns(None) as w:
263+
with warnings.catch_warnings(record=True) as w:
264+
warnings.simplefilter("always")
259265
r = cached_func(1, 2)
260266

261267
assert r == 3
@@ -269,7 +275,8 @@ def test_lru_only(self, hybrid_cache_clear):
269275

270276
assert r == 3
271277

272-
with pytest.warns(None) as w:
278+
with warnings.catch_warnings(record=True) as w:
279+
warnings.simplefilter("always")
273280
r = cached_func(1, 2)
274281

275282
assert r == 3
@@ -288,7 +295,8 @@ def test_staggered_cache(self, joblib_cache_verbose, hybrid_cache_clear, capfd):
288295

289296
assert r == 3
290297

291-
with pytest.warns(None) as w:
298+
with warnings.catch_warnings(record=True) as w:
299+
warnings.simplefilter("always")
292300
r = cached_func(1, 2)
293301

294302
# This should not hit the joblib cache, as the lru cache should have been used
@@ -317,7 +325,8 @@ def test_joblib_cache_survives_clear(self, joblib_cache_verbose, hybrid_cache_cl
317325

318326
cached_func_new = hybrid_cache(joblib_cache_verbose, 2)(example_func)
319327

320-
with pytest.warns(None) as w:
328+
with warnings.catch_warnings(record=True) as w:
329+
warnings.simplefilter("always")
321330
r = cached_func_new(1, 2)
322331

323332
# This time this should hit the joblib cache, as the lru cache should have been cleared
@@ -330,7 +339,8 @@ def test_joblib_cache_survives_clear(self, joblib_cache_verbose, hybrid_cache_cl
330339
assert not w
331340

332341
# And now the lru cache should be used again
333-
with pytest.warns(None) as w:
342+
with warnings.catch_warnings(record=True) as w:
343+
warnings.simplefilter("always")
334344
r = cached_func_new(1, 2)
335345

336346
# This time this should hit the joblib cache, as the lru cache should have been cleared

tests/test_pipelines/test_optimize.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from tempfile import TemporaryDirectory
23
from typing import Union
34
from unittest.mock import patch
@@ -8,6 +9,7 @@
89
import pytest
910
from sklearn.model_selection import ParameterGrid, PredefinedSplit
1011

12+
from tests.conftest import warns_or_none
1113
from tests.test_pipelines.conftest import (
1214
DummyDataset,
1315
DummyOptimizablePipeline,
@@ -147,7 +149,7 @@ def test_return_optimized_single(self, return_optimized):
147149
if isinstance(return_optimized, str) and not return_optimized.endswith("score"):
148150
warning = UserWarning
149151

150-
with pytest.warns(warning) as w:
152+
with warns_or_none(warning) as w:
151153
gs.optimize(DummyDataset())
152154

153155
if isinstance(return_optimized, str) and not return_optimized.endswith("score"):
@@ -760,7 +762,7 @@ def test_safe_optimize(self, use_safe):
760762
with patch.object(DummyOptimizablePipelineUnsafe, "self_optimize", return_value=optimized_pipe) as mock:
761763
mock.__name__ = "self_optimize"
762764
warning = PotentialUserErrorWarning if use_safe else None
763-
with pytest.warns(warning) as w:
765+
with warns_or_none(warning) as w:
764766
Optimize(DummyOptimizablePipelineUnsafe(), safe_optimize=use_safe).optimize(ds)
765767

766768
if use_safe:
@@ -818,16 +820,18 @@ def test_warning(self):
818820

819821
assert len(w.list) == 1
820822

821-
with pytest.warns(None) as w:
823+
with warnings.catch_warnings(record=True) as w:
824+
warnings.simplefilter("always")
822825
DummyOptimize(DummyPipeline()).optimize(dataset=None)
823826

824-
assert len(w.list) == 0
827+
assert len(w) == 0
825828

826829
def test_warning_suppression(self):
827-
with pytest.warns(None) as w:
830+
with warnings.catch_warnings(record=True) as w:
831+
warnings.simplefilter("always")
828832
DummyOptimize(DummyOptimizablePipeline(), ignore_potential_user_error_warning=True).optimize(dataset=None)
829833

830-
assert len(w.list) == 0
834+
assert len(w) == 0
831835

832836

833837
class TestOptimizeBase:

tests/test_safe_decorator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55

6+
from tests.conftest import warns_or_none
67
from tests.test_pipelines.conftest import DummyDataset
78
from tpcp import (
89
Dataset,
@@ -90,7 +91,7 @@ def test_wrapper_checks_name(self, name, warn):
9091
ds = DummyDataset()
9192

9293
warning = PotentialUserErrorWarning if warn else None
93-
with pytest.warns(warning) as w:
94+
with warns_or_none(warning) as w:
9495
make_action_safe(test_func)(DummyActionPipelineUnsafe(), ds)
9596

9697
if warn:
@@ -148,7 +149,7 @@ def test_wrapper_checks_name(self, name, warn):
148149
ds = DummyDataset()
149150

150151
warning = PotentialUserErrorWarning if warn else None
151-
with pytest.warns(warning) as w:
152+
with warns_or_none(warning) as w:
152153
make_optimize_safe(test_func)(DummyOptimizablePipelineUnsafe(), ds)
153154

154155
if warn:
@@ -166,7 +167,7 @@ def test_optimize_warns(self, output, warn):
166167
DummyOptimizablePipelineUnsafe.self_optimize
167168
)
168169
warning = PotentialUserErrorWarning if warn else None
169-
with pytest.warns(warning) as w:
170+
with warns_or_none(warning) as w:
170171
DummyOptimizablePipelineUnsafe().self_optimize(ds)
171172

172173
if len(w) > 0:

0 commit comments

Comments
 (0)