99import contextlib
1010import enum
1111import warnings
12- from collections .abc import Callable , Iterator , Sequence
12+ from collections .abc import Callable , Generator , Iterator , Sequence
1313from functools import wraps
1414from types import ModuleType
1515from typing import TYPE_CHECKING , Any , ParamSpec , TypeVar , cast
@@ -216,8 +216,11 @@ def test_myfunc(xp):
216216
217217
218218def patch_lazy_xp_functions (
219- request : pytest .FixtureRequest , monkeypatch : pytest .MonkeyPatch , * , xp : ModuleType
220- ) -> None :
219+ request : pytest .FixtureRequest ,
220+ monkeypatch : pytest .MonkeyPatch | None = None ,
221+ * ,
222+ xp : ModuleType ,
223+ ) -> contextlib .AbstractContextManager [None ]:
221224 """
222225 Test lazy execution of functions tagged with :func:`lazy_xp_function`.
223226
@@ -233,10 +236,15 @@ def patch_lazy_xp_functions(
233236 This function should be typically called by your library's `xp` fixture that runs
234237 tests on multiple backends::
235238
236- @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
237- def xp(request, monkeypatch):
238- patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
239- return request.param
239+ @pytest.fixture(params=[
240+ numpy,
241+ array_api_strict,
242+ pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
243+ pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
244+ ])
245+ def xp(request):
246+ with patch_lazy_xp_functions(request, xp=request.param):
247+ yield request.param
240248
241249 but it can be otherwise be called by the test itself too.
242250
@@ -245,18 +253,50 @@ def xp(request, monkeypatch):
245253 request : pytest.FixtureRequest
246254 Pytest fixture, as acquired by the test itself or by one of its fixtures.
247255 monkeypatch : pytest.MonkeyPatch
248- Pytest fixture, as acquired by the test itself or by one of its fixtures.
256+ Deprecated
249257 xp : array_namespace
250258 Array namespace to be tested.
251259
252260 See Also
253261 --------
254262 lazy_xp_function : Tag a function to be tested on lazy backends.
255263 pytest.FixtureRequest : `request` test function parameter.
264+
265+ Notes
266+ -----
267+ This context manager monkey-patches modules and as such is thread unsafe
268+ on Dask and JAX. If you run your test suite with
269+ `pytest-run-parallel <https://github.com/Quansight-Labs/pytest-run-parallel/>`_,
270+ you should mark these backends with ``@pytest.mark.thread_unsafe``, as shown in
271+ the example above.
256272 """
257273 mod = cast (ModuleType , request .module )
258274 mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
259275
276+ to_revert : list [tuple [ModuleType , str , object ]] = []
277+
278+ def temp_setattr (mod : ModuleType , name : str , func : object ) -> None :
279+ """
280+ Variant of monkeypatch.setattr, which allows monkey-patching only selected
281+ parameters of a test so that pytest-run-parallel can run on the remainder.
282+ """
283+ assert hasattr (mod , name )
284+ to_revert .append ((mod , name , getattr (mod , name )))
285+ setattr (mod , name , func )
286+
287+ if monkeypatch is not None :
288+ warnings .warn (
289+ (
290+ "The `monkeypatch` parameter is deprecated and will be removed in a "
291+ "future version. "
292+ "Use `patch_lazy_xp_function` as a context manager instead."
293+ ),
294+ DeprecationWarning ,
295+ stacklevel = 2 ,
296+ )
297+ # Enable using patch_lazy_xp_function not as a context manager
298+ temp_setattr = monkeypatch .setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
299+
260300 def iter_tagged () -> ( # type: ignore[explicit-any]
261301 Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
262302 ):
@@ -279,13 +319,26 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
279319 elif n is False :
280320 n = 0
281321 wrapped = _dask_wrap (func , n )
282- monkeypatch . setattr (mod , name , wrapped )
322+ temp_setattr (mod , name , wrapped )
283323
284324 elif is_jax_namespace (xp ):
285325 for mod , name , func , tags in iter_tagged ():
286326 if tags ["jax_jit" ]:
287327 wrapped = jax_autojit (func )
288- monkeypatch .setattr (mod , name , wrapped )
328+ temp_setattr (mod , name , wrapped )
329+
330+ # We can't just decorate patch_lazy_xp_functions with
331+ # @contextlib.contextmanager because it would not work with the
332+ # deprecated monkeypatch when not used as a context manager.
333+ @contextlib .contextmanager
334+ def revert_on_exit () -> Generator [None ]:
335+ try :
336+ yield
337+ finally :
338+ for mod , name , orig_func in to_revert :
339+ setattr (mod , name , orig_func )
340+
341+ return revert_on_exit ()
289342
290343
291344class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments