@@ -57,9 +57,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
5757 """
5858 Tag a function to be tested on lazy backends.
5959
60- Tag a function, which must be imported in the test module globals, so that when any
61- tests defined in the same module are executed with ``xp=jax.numpy`` the function is
62- replaced with a jitted version of itself, and when it is executed with
60+ Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
61+ function is replaced with a jitted version of itself, and when it is executed with
6362 ``xp=dask.array`` the function will raise if it attempts to materialize the graph.
6463 This will be later expanded to provide test coverage for other lazy backends.
6564
@@ -121,19 +120,59 @@ def test_myfunc(xp):
121120
122121 Notes
123122 -----
124- A test function can circumvent this monkey-patching system by calling `func` as an
125- attribute of the original module. You need to sanitize your code to make sure this
126- does not happen .
123+ In order for this tag to be effective, the test function must be imported into the
124+ test module globals without namespace; alternatively its namespace must be declared
125+ in a ``lazy_xp_modules`` list in the test module globals .
127126
128- Example::
127+ Example 1 ::
129128
130- import mymodule from mymodule import myfunc
129+ from mymodule import myfunc
131130
132131 lazy_xp_function(myfunc)
133132
134133 def test_myfunc(xp):
135- a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
136- mymodule.myfunc(a) # This is not
134+ x = myfunc(xp.asarray([1, 2]))
135+
136+ Example 2::
137+
138+ import mymodule
139+
140+ lazy_xp_modules = [mymodule]
141+ lazy_xp_function(mymodule.myfunc)
142+
143+ def test_myfunc(xp):
144+ x = mymodule.myfunc(xp.asarray([1, 2]))
145+
146+ A test function can circumvent this monkey-patching system by using a namespace
147+ outside of the two above patterns. You need to sanitize your code to make sure this
148+ only happens intentionally.
149+
150+ Example 1::
151+
152+ import mymodule
153+ from mymodule import myfunc
154+
155+ lazy_xp_function(myfunc)
156+
157+ def test_myfunc(xp):
158+ a = xp.asarray([1, 2])
159+ b = myfunc(a) # This is jitted when xp=jax.numpy
160+ c = mymodule.myfunc(a) # This is not
161+
162+ Example 2::
163+
164+ import mymodule
165+
166+ class naked:
167+ myfunc = mymodule.myfunc
168+
169+ lazy_xp_modules = [mymodule]
170+ lazy_xp_function(mymodule.myfunc)
171+
172+ def test_myfunc(xp):
173+ a = xp.asarray([1, 2])
174+ b = mymodule.myfunc(a) # This is jitted when xp=jax.numpy
175+ c = naked.myfunc(a) # This is not
137176 """
138177 tags = {
139178 "allow_dask_compute" : allow_dask_compute ,
@@ -154,11 +193,13 @@ def patch_lazy_xp_functions(
154193 Test lazy execution of functions tagged with :func:`lazy_xp_function`.
155194
156195 If ``xp==jax.numpy``, search for all functions which have been tagged with
157- :func:`lazy_xp_function` in the globals of the module that defines the current test
196+ :func:`lazy_xp_function` in the globals of the module that defines the current test,
197+ as well as in the ``lazy_xp_modules`` list in the globals of the same module,
158198 and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
159199
160200 If ``xp==dask.array``, wrap the functions with a decorator that disables
161- ``compute()`` and ``persist()``.
201+ ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
202+ eagerly.
162203
163204 This function should be typically called by your library's `xp` fixture that runs
164205 tests on multiple backends::
@@ -184,29 +225,33 @@ def xp(request, monkeypatch):
184225 lazy_xp_function : Tag a function to be tested on lazy backends.
185226 pytest.FixtureRequest : `request` test function parameter.
186227 """
187- globals_ = cast ("dict[str, Any]" , request .module .__dict__ ) # type: ignore[no-any-explicit]
188-
189- def iter_tagged () -> Iterator [tuple [str , Callable [..., Any ], dict [str , Any ]]]: # type: ignore[no-any-explicit]
190- for name , func in globals_ .items ():
191- tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
192- with contextlib .suppress (AttributeError ):
193- tags = func ._lazy_xp_function # pylint: disable=protected-access
194- if tags is None :
195- with contextlib .suppress (KeyError , TypeError ):
196- tags = _ufuncs_tags [func ]
197- if tags is not None :
198- yield name , func , tags
228+ mod = cast (ModuleType , request .module )
229+ mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
230+
231+ def iter_tagged () -> ( # type: ignore[no-any-explicit]
232+ Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
233+ ):
234+ for mod in mods :
235+ for name , func in mod .__dict__ .items ():
236+ tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
237+ with contextlib .suppress (AttributeError ):
238+ tags = func ._lazy_xp_function # pylint: disable=protected-access
239+ if tags is None :
240+ with contextlib .suppress (KeyError , TypeError ):
241+ tags = _ufuncs_tags [func ]
242+ if tags is not None :
243+ yield mod , name , func , tags
199244
200245 if is_dask_namespace (xp ):
201- for name , func , tags in iter_tagged ():
246+ for mod , name , func , tags in iter_tagged ():
202247 n = tags ["allow_dask_compute" ]
203248 wrapped = _dask_wrap (func , n )
204- monkeypatch .setitem ( globals_ , name , wrapped )
249+ monkeypatch .setattr ( mod , name , wrapped )
205250
206251 elif is_jax_namespace (xp ):
207252 import jax
208253
209- for name , func , tags in iter_tagged ():
254+ for mod , name , func , tags in iter_tagged ():
210255 if tags ["jax_jit" ]:
211256 # suppress unused-ignore to run mypy in -e lint as well as -e dev
212257 wrapped = cast ( # type: ignore[no-any-explicit]
@@ -217,7 +262,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
217262 static_argnames = tags ["static_argnames" ],
218263 ),
219264 )
220- monkeypatch .setitem ( globals_ , name , wrapped )
265+ monkeypatch .setattr ( mod , name , wrapped )
221266
222267
223268class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments