@@ -87,8 +87,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
8787 graph.
8888 jax_jit : bool, optional
8989 Set to True to replace `func` with ``jax.jit(func)`` after calling the
90- :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False if
91- `func` is only compatible with eager (non-jitted) JAX. Default: True.
90+ :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False
91+ if `func` is only compatible with eager (non-jitted) JAX. Default: True.
9292 static_argnums : int | Sequence[int], optional
9393 Passed to jax.jit. Positional arguments to treat as static (compile-time
9494 constant). Default: infer from `static_argnames` using
@@ -113,7 +113,7 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
113113 def test_myfunc(xp):
114114 a = xp.asarray([1, 2])
115115 # When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
116- # When xp=dask.array, crash on compute() or persist()
116+ # When xp=dask.array, crash on compute() or persist()
117117 b = myfunc(a)
118118
119119 Notes
@@ -150,8 +150,8 @@ def patch_lazy_xp_functions(
150150 :func:`lazy_xp_function` in the globals of the module that defines the current test
151151 and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
152152
153- If ``xp==dask.array``, wrap the functions with a decorator that disables ``compute()``
154- and ``persist()``.
153+ If ``xp==dask.array``, wrap the functions with a decorator that disables
154+ ``compute()`` and ``persist()``.
155155
156156 This function should be typically called by your library's `xp` fixture that runs
157157 tests on multiple backends::
0 commit comments