@@ -56,9 +56,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
5656 """
5757 Tag a function to be tested on lazy backends.
5858
59- Tag a function, which must be imported in the test module globals, so that when any
60- tests defined in the same module are executed with ``xp=jax.numpy`` the function is
61- replaced with a jitted version of itself, and when it is executed with
59+ Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
60+ function is replaced with a jitted version of itself, and when it is executed with
6261 ``xp=dask.array`` the function will raise if it attempts to materialize the graph.
6362 This will be later expanded to provide test coverage for other lazy backends.
6463
@@ -120,19 +119,59 @@ def test_myfunc(xp):
120119
121120 Notes
122121 -----
123- A test function can circumvent this monkey-patching system by calling `func` as an
124- attribute of the original module. You need to sanitize your code to make sure this
125- does not happen .
122+ In order for this tag to be effective, the test function must be imported into the
123+ test module globals without its namespace; alternatively its namespace must be
124+ declared in a ``lazy_xp_modules`` list in the test module globals .
126125
127- Example::
126+ Example 1 ::
128127
129- import mymodule from mymodule import myfunc
128+ from mymodule import myfunc
130129
131130 lazy_xp_function(myfunc)
132131
133132 def test_myfunc(xp):
134- a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
135- mymodule.myfunc(a) # This is not
133+ x = myfunc(xp.asarray([1, 2]))
134+
135+ Example 2::
136+
137+ import mymodule
138+
139+ lazy_xp_modules = [mymodule]
140+ lazy_xp_function(mymodule.myfunc)
141+
142+ def test_myfunc(xp):
143+ x = mymodule.myfunc(xp.asarray([1, 2]))
144+
145+ A test function can circumvent this monkey-patching system by using a namespace
146+ outside of the two above patterns. You need to sanitize your code to make sure this
147+ only happens intentionally.
148+
149+ Example 1::
150+
151+ import mymodule
152+ from mymodule import myfunc
153+
154+ lazy_xp_function(myfunc)
155+
156+ def test_myfunc(xp):
157+ a = xp.asarray([1, 2])
158+ b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
159+ c = mymodule.myfunc(a) # This is not
160+
161+ Example 2::
162+
163+ import mymodule
164+
165+ class naked:
166+ myfunc = mymodule.myfunc
167+
168+ lazy_xp_modules = [mymodule]
169+ lazy_xp_function(mymodule.myfunc)
170+
171+ def test_myfunc(xp):
172+ a = xp.asarray([1, 2])
173+ b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
174+ c = naked.myfunc(a) # This is not
136175 """
137176 tags = {
138177 "allow_dask_compute" : allow_dask_compute ,
@@ -153,11 +192,13 @@ def patch_lazy_xp_functions(
153192 Test lazy execution of functions tagged with :func:`lazy_xp_function`.
154193
155194 If ``xp==jax.numpy``, search for all functions which have been tagged with
156- :func:`lazy_xp_function` in the globals of the module that defines the current test
195+ :func:`lazy_xp_function` in the globals of the module that defines the current test,
196+ as well as in the ``lazy_xp_modules`` list in the globals of the same module,
157197 and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
158198
159199 If ``xp==dask.array``, wrap the functions with a decorator that disables
160- ``compute()`` and ``persist()``.
200+ ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
201+ eagerly.
161202
162203 This function should be typically called by your library's `xp` fixture that runs
163204 tests on multiple backends::
@@ -183,29 +224,33 @@ def xp(request, monkeypatch):
183224 lazy_xp_function : Tag a function to be tested on lazy backends.
184225 pytest.FixtureRequest : `request` test function parameter.
185226 """
186- globals_ = cast ("dict[str, Any]" , request .module .__dict__ ) # type: ignore[no-any-explicit]
187-
188- def iter_tagged () -> Iterator [tuple [str , Callable [..., Any ], dict [str , Any ]]]: # type: ignore[no-any-explicit]
189- for name , func in globals_ .items ():
190- tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
191- with contextlib .suppress (AttributeError ):
192- tags = func ._lazy_xp_function # pylint: disable=protected-access
193- if tags is None :
194- with contextlib .suppress (KeyError , TypeError ):
195- tags = _ufuncs_tags [func ]
196- if tags is not None :
197- yield name , func , tags
227+ mod = cast (ModuleType , request .module )
228+ mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
229+
230+ def iter_tagged () -> ( # type: ignore[no-any-explicit]
231+ Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
232+ ):
233+ for mod in mods :
234+ for name , func in mod .__dict__ .items ():
235+ tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
236+ with contextlib .suppress (AttributeError ):
237+ tags = func ._lazy_xp_function # pylint: disable=protected-access
238+ if tags is None :
239+ with contextlib .suppress (KeyError , TypeError ):
240+ tags = _ufuncs_tags [func ]
241+ if tags is not None :
242+ yield mod , name , func , tags
198243
199244 if is_dask_namespace (xp ):
200- for name , func , tags in iter_tagged ():
245+ for mod , name , func , tags in iter_tagged ():
201246 n = tags ["allow_dask_compute" ]
202247 wrapped = _dask_wrap (func , n )
203- monkeypatch .setitem ( globals_ , name , wrapped )
248+ monkeypatch .setattr ( mod , name , wrapped )
204249
205250 elif is_jax_namespace (xp ):
206251 import jax
207252
208- for name , func , tags in iter_tagged ():
253+ for mod , name , func , tags in iter_tagged ():
209254 if tags ["jax_jit" ]:
210255 # suppress unused-ignore to run mypy in -e lint as well as -e dev
211256 wrapped = cast ( # type: ignore[no-any-explicit]
@@ -216,7 +261,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
216261 static_argnames = tags ["static_argnames" ],
217262 ),
218263 )
219- monkeypatch .setitem ( globals_ , name , wrapped )
264+ monkeypatch .setattr ( mod , name , wrapped )
220265
221266
222267class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments