99import  contextlib 
1010import  enum 
1111import  warnings 
12- from  collections .abc  import  Callable , Iterator , Sequence 
12+ from  collections .abc  import  Callable , Generator ,  Iterator , Sequence 
1313from  functools  import  wraps 
1414from  types  import  ModuleType 
1515from  typing  import  TYPE_CHECKING , Any , ParamSpec , TypeVar , cast 
@@ -216,8 +216,11 @@ def test_myfunc(xp):
216216
217217
218218def  patch_lazy_xp_functions (
219-     request : pytest .FixtureRequest , monkeypatch : pytest .MonkeyPatch , * , xp : ModuleType 
220- ) ->  None :
219+     request : pytest .FixtureRequest ,
220+     monkeypatch : pytest .MonkeyPatch  |  None  =  None ,
221+     * ,
222+     xp : ModuleType ,
223+ ) ->  contextlib .AbstractContextManager [None ]:
221224    """ 
222225    Test lazy execution of functions tagged with :func:`lazy_xp_function`. 
223226
@@ -233,10 +236,15 @@ def patch_lazy_xp_functions(
233236    This function should be typically called by your library's `xp` fixture that runs 
234237    tests on multiple backends:: 
235238
236-         @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array]) 
237-         def xp(request, monkeypatch): 
238-             patch_lazy_xp_functions(request, monkeypatch, xp=request.param) 
239-             return request.param 
239+         @pytest.fixture(params=[ 
240+             numpy, 
241+             array_api_strict, 
242+             pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe), 
243+             pytest.param(dask.array, marks=pytest.mark.thread_unsafe), 
244+         ]) 
245+         def xp(request): 
246+             with patch_lazy_xp_functions(request, xp=request.param): 
247+                 yield request.param 
240248
241249    but it can be otherwise be called by the test itself too. 
242250
@@ -245,18 +253,50 @@ def xp(request, monkeypatch):
245253    request : pytest.FixtureRequest 
246254        Pytest fixture, as acquired by the test itself or by one of its fixtures. 
247255    monkeypatch : pytest.MonkeyPatch 
248-         Pytest fixture, as acquired by the test itself or by one of its fixtures.  
256+         Deprecated  
249257    xp : array_namespace 
250258        Array namespace to be tested. 
251259
252260    See Also 
253261    -------- 
254262    lazy_xp_function : Tag a function to be tested on lazy backends. 
255263    pytest.FixtureRequest : `request` test function parameter. 
264+ 
265+     Notes 
266+     ----- 
267+     This context manager monkey-patches modules and as such is thread unsafe 
268+     on Dask and JAX. If you run your test suite with 
269+     `pytest-run-parallel <https://github.com/Quansight-Labs/pytest-run-parallel/>`_, 
270+     you should mark these backends with ``@pytest.mark.thread_unsafe``, as shown in 
271+     the example above. 
256272    """ 
257273    mod  =  cast (ModuleType , request .module )
258274    mods  =  [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
259275
276+     to_revert : list [tuple [ModuleType , str , object ]] =  []
277+ 
278+     def  temp_setattr (mod : ModuleType , name : str , func : object ) ->  None :
279+         """ 
280+         Variant of monkeypatch.setattr, which allows monkey-patching only selected 
281+         parameters of a test so that pytest-run-parallel can run on the remainder. 
282+         """ 
283+         assert  hasattr (mod , name )
284+         to_revert .append ((mod , name , getattr (mod , name )))
285+         setattr (mod , name , func )
286+ 
287+     if  monkeypatch  is  not   None :
288+         warnings .warn (
289+             (
290+                 "The `monkeypatch` parameter is deprecated and will be removed in a " 
291+                 "future version. " 
292+                 "Use `patch_lazy_xp_function` as a context manager instead." 
293+             ),
294+             DeprecationWarning ,
295+             stacklevel = 2 ,
296+         )
297+         # Enable using patch_lazy_xp_function not as a context manager 
298+         temp_setattr  =  monkeypatch .setattr   # type: ignore[assignment]  # pyright: ignore[reportAssignmentType] 
299+ 
260300    def  iter_tagged () ->  (
261301        Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
262302    ):
@@ -279,13 +319,26 @@ def iter_tagged() -> (
279319            elif  n  is  False :
280320                n  =  0 
281321            wrapped  =  _dask_wrap (func , n )
282-             monkeypatch . setattr (mod , name , wrapped )
322+             temp_setattr (mod , name , wrapped )
283323
284324    elif  is_jax_namespace (xp ):
285325        for  mod , name , func , tags  in  iter_tagged ():
286326            if  tags ["jax_jit" ]:
287327                wrapped  =  jax_autojit (func )
288-                 monkeypatch .setattr (mod , name , wrapped )
328+                 temp_setattr (mod , name , wrapped )
329+ 
330+     # We can't just decorate patch_lazy_xp_functions with 
331+     # @contextlib.contextmanager because it would not work with the 
332+     # deprecated monkeypatch when not used as a context manager. 
333+     @contextlib .contextmanager  
334+     def  revert_on_exit () ->  Generator [None ]:
335+         try :
336+             yield 
337+         finally :
338+             for  mod , name , orig_func  in  to_revert :
339+                 setattr (mod , name , orig_func )
340+ 
341+     return  revert_on_exit ()
289342
290343
291344class  CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments