2323
2424 NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
2525 P = ParamSpec ("P" )
26+ else :
27+ # Sphinx hacks
28+ NumPyObject = Any
29+
30+ class P : # pylint: disable=missing-class-docstring
31+ args : tuple
32+ kwargs : dict
2633
2734
2835@overload
@@ -47,7 +54,7 @@ def apply_numpy_func( # type: ignore[valid-type]
4754) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
4855
4956
50- def apply_numpy_func ( # type: ignore[valid-type]
57+ def apply_numpy_func ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
5158 func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
5259 * args : Array ,
5360 shape : tuple [int , ...] | Sequence [tuple [int , ...]] | None = None ,
@@ -69,7 +76,7 @@ def apply_numpy_func( # type: ignore[valid-type]
6976 as depending on the backend it may be executed more than once.
7077 *args : Array
7178 One or more Array API compliant arrays. You need to be able to apply
72- ``np .asarray()` ` to them to convert them to numpy; read notes below about
79+ :func:`numpy .asarray` to them to convert them to numpy; read notes below about
7380 specific backends.
7481 shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
7582 Output shape or sequence of output shapes, one for each output of `func`.
@@ -97,25 +104,23 @@ def apply_numpy_func( # type: ignore[valid-type]
97104 This allows applying eager functions to jitted JAX arrays, which are lazy.
98105 The function won't be applied until the JAX array is materialized.
99106
100- The `JAX transfer guard
101- <https://jax.readthedocs.io/en/latest/transfer_guard.html>`_
102- may prevent arrays on a GPU device from being transferred back to CPU.
103- This is treated as an implicit transfer.
107+ The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
108+ transferred back to CPU. This is treated as an implicit transfer.
104109
105110 PyTorch, CuPy
106111 These backends raise by default if you attempt to convert arrays on a GPU device
107112 to NumPy.
108113
109114 Sparse
110- By default, sparse prevents implicit densification through ``np.asarray`.
111- `This safety mechanism can be disabled
115+ By default, sparse prevents implicit densification through
116+ :func:`numpy.asarray`. `This safety mechanism can be disabled
112117 <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
113118
114119 Dask
115120 This allows applying eager functions to dask arrays.
116121 The dask graph won't be computed.
117122
118- `apply_numpy_func` doesn't know if `func` reduces along any axes and shape
123+ `apply_numpy_func` doesn't know if `func` reduces along any axes; also, shape
119124 changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
120125 will be rechunked into a single chunk.
121126
@@ -125,9 +130,19 @@ def apply_numpy_func( # type: ignore[valid-type]
125130
126131 The outputs will also be returned as a single chunk and you should consider
127132 rechunking them into smaller chunks afterwards.
133+
128134 If you want to distribute the calculation across multiple workers, you
129- should use `dask.array.map_blocks`, `dask.array.blockwise`,
130- `dask.array.map_overlap`, or a native Dask wrapper instead of this function.
135+ should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
136+ :func:`dask.array.blockwise`, or a native Dask wrapper instead of
137+ `apply_numpy_func`.
138+
139+ See Also
140+ --------
141+ jax.transfer_guard
142+ jax.pure_callback
143+ dask.array.map_blocks
144+ dask.array.map_overlap
145+ dask.array.blockwise
131146 """
132147 if xp is None :
133148 xp = array_namespace (* args )
@@ -239,8 +254,8 @@ def _npfunc_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT
239254
240255 Any keyword arguments are passed through verbatim to the wrapped function.
241256
242- Raise if np.asarray() raises on any input. This typically happens if the input is
243- lazy and has a guard against being implicitly turned into a NumPy array (e.g.
257+ Raise if np.asarray raises on any input. This typically happens if the input is lazy
258+ and has a guard against being implicitly turned into a NumPy array (e.g.
244259 densification for sparse arrays, device->host transfer for cupy and torch arrays).
245260 """
246261
0 commit comments