99import contextlib
1010import enum
1111import warnings
12- from collections .abc import Callable , Iterator , Sequence
12+ from collections .abc import Callable , Generator , Iterator , Sequence
1313from functools import wraps
1414from types import ModuleType
1515from typing import TYPE_CHECKING , Any , ParamSpec , TypeVar , cast
@@ -216,8 +216,11 @@ def test_myfunc(xp):
216216
217217
218218def patch_lazy_xp_functions (
219- request : pytest .FixtureRequest , monkeypatch : pytest .MonkeyPatch , * , xp : ModuleType
220- ) -> None :
219+ request : pytest .FixtureRequest ,
220+ monkeypatch : pytest .MonkeyPatch | None = None ,
221+ * ,
222+ xp : ModuleType ,
223+ ) -> contextlib .AbstractContextManager [None ]:
221224 """
222225 Test lazy execution of functions tagged with :func:`lazy_xp_function`.
223226
@@ -233,10 +236,15 @@ def patch_lazy_xp_functions(
233236 This function should be typically called by your library's `xp` fixture that runs
234237 tests on multiple backends::
235238
236- @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
237- def xp(request, monkeypatch):
238- patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
239- return request.param
239+ @pytest.fixture(params=[
240+ numpy,
241+ array_api_strict,
242+ pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
243+ pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
244+ ])
245+ def xp(request):
246+ with patch_lazy_xp_functions(request, xp=request.param):
247+ yield request.param
240248
241249 but it can be otherwise be called by the test itself too.
242250
@@ -245,18 +253,49 @@ def xp(request, monkeypatch):
245253 request : pytest.FixtureRequest
246254 Pytest fixture, as acquired by the test itself or by one of its fixtures.
247255 monkeypatch : pytest.MonkeyPatch
248- Pytest fixture, as acquired by the test itself or by one of its fixtures.
256+ Deprecated
249257 xp : array_namespace
250258 Array namespace to be tested.
251259
252260 See Also
253261 --------
254262 lazy_xp_function : Tag a function to be tested on lazy backends.
255263 pytest.FixtureRequest : `request` test function parameter.
264+
265+ Notes
266+ -----
267+ This context manager is thread unsafe on Dask and JAX. If you run your test suite
268+ with
269+ `pytest-run-parallel <https://github.com/Quansight-Labs/pytest-run-parallel/>`_,
270+ you should mark the test or the fixture with ``@pytest.mark.thread_unsafe``.
256271 """
257272 mod = cast (ModuleType , request .module )
258273 mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
259274
275+ to_revert : list [tuple [ModuleType , str , object ]] = []
276+
277+ def temp_setattr (mod : ModuleType , name : str , func : object ) -> None :
278+ """
279+ Variant of monkeypatch.setattr, which allows monkey-patching only selected
280+ parameters of a test so that pytest-run-parallel can run on the remainder.
281+ """
282+ assert hasattr (mod , name )
283+ to_revert .append ((mod , name , getattr (mod , name )))
284+ setattr (mod , name , func )
285+
286+ if monkeypatch is not None :
287+ warnings .warn (
288+ (
289+ "The `monkeypatch` parameter is deprecated and will be removed in a "
290+ "future version. "
291+ "Use `patch_lazy_xp_function` as a context manager instead."
292+ ),
293+ DeprecationWarning ,
294+ stacklevel = 2 ,
295+ )
296+ # Enable using patch_lazy_xp_function not as a context manager
297+ temp_setattr = monkeypatch .setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
298+
260299 def iter_tagged () -> ( # type: ignore[explicit-any]
261300 Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
262301 ):
@@ -279,13 +318,26 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
279318 elif n is False :
280319 n = 0
281320 wrapped = _dask_wrap (func , n )
282- monkeypatch . setattr (mod , name , wrapped )
321+ temp_setattr (mod , name , wrapped )
283322
284323 elif is_jax_namespace (xp ):
285324 for mod , name , func , tags in iter_tagged ():
286325 if tags ["jax_jit" ]:
287326 wrapped = jax_autojit (func )
288- monkeypatch .setattr (mod , name , wrapped )
327+ temp_setattr (mod , name , wrapped )
328+
329+ # We can't just decorate patch_lazy_xp_functions with
330+ # @contextlib.contextmanager because it would not work with the
331+ # deprecated monkeypatch when not used as a context manager.
332+ @contextlib .contextmanager
333+ def revert_on_exit () -> Generator [None ]:
334+ try :
335+ yield
336+ finally :
337+ for mod , name , orig_func in to_revert :
338+ setattr (mod , name , orig_func )
339+
340+ return revert_on_exit ()
289341
290342
291343class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments