@@ -39,7 +39,8 @@ def apply_numpy_func(  # type: ignore[no-any-explicit]
     **kwargs: Any,
 ) -> tuple[Array, ...]:
     """
-    Apply a function that operates on NumPY arrays to any Array API compliant arrays.
+    Apply a function that operates on NumPy arrays to any Array API compliant arrays,
+    as long as you can apply ``np.asarray`` to them.

     Parameters
     ----------
@@ -103,7 +104,7 @@ def apply_numpy_func(  # type: ignore[no-any-explicit]

     Notes
     -----
-    JAX:
+    JAX
         This allows applying eager functions to jitted JAX arrays, which are lazy.
         The function won't be applied until the JAX array is materialized.
@@ -112,16 +113,21 @@ def apply_numpy_func(  # type: ignore[no-any-explicit]
         may prevent arrays on a GPU device from being transferred back to CPU.
         This is treated as an implicit transfer.

-    PyTorch, CuPy:
+    PyTorch, CuPy
         These backends raise by default if you attempt to convert arrays on a GPU device
         to NumPy.

-    Dask:
-        This function allows applying func to the chunks of dask arrays.
+    Sparse
+        By default, sparse prevents implicit densification through ``np.asarray``.
+        `This safety mechanism can be disabled
+        <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
+
+    Dask
+        This allows applying eager functions to the individual chunks of dask arrays.
         The dask graph won't be computed. As a special limitation, `func` must return
         exactly one output.

-        In order to allow Dask you need to specify at least
+        In order to enable running on Dask, you need to specify at least
         `input_indices`, `output_indices`, and `core_indices`, but you may also need
         `adjust_chunks` and `new_axes` depending on the function.

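As a reading aid for the Dask note above: a minimal sketch of what such a call
could look like. The keyword names are taken from this diff, but the exact
signature is not shown in these hunks, so the positional calling convention
and the one-element result tuple are assumptions:

    import dask.array as da
    import numpy as np

    # One chunk along the core axis 'i', as core_indices='i' requires.
    x = da.arange(10, chunks=10)

    # np.cumsum couples the whole axis, so 'i' is declared a core index;
    # apply_numpy_func raises if that axis spans multiple chunks.
    (y,) = apply_numpy_func(
        np.cumsum, x,
        input_indices='i', output_indices='i', core_indices='i',
    )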
@@ -147,7 +153,13 @@ def apply_numpy_func(  # type: ignore[no-any-explicit]
     ... core_indices='i')

     This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
-    along multiple chunks.
+    along multiple chunks, thus forcing the final user to rechunk ahead of time:
+
+    >>> x = x.rechunk({0: -1})
+
+    This always needs to be a conscious decision on the final user's part, as the
+    new chunks will be larger than the old ones and may cause memory issues, unless
+    the chunk size is reduced along a different, non-core axis.
     """
     if xp is None:
         xp = array_namespace(*args)
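The rechunking requirement in the last hunk can be demonstrated with plain
Dask, independently of `apply_numpy_func`; the `da.map_blocks` calls below are
only an illustration of per-chunk application, not the library's actual
implementation:

    import dask.array as da
    import numpy as np

    x = da.arange(10, chunks=5)  # axis 0 is broken into two chunks
    print(x.chunks)              # ((5, 5),)

    # Applying np.cumsum per chunk computes two independent partial sums,
    # which is wrong for a function that couples the whole axis:
    wrong = da.map_blocks(np.cumsum, x)

    # The conscious fix asked of the final user: merge axis 0 into a single
    # (larger, hence more memory-hungry) chunk before applying the function.
    x = x.rechunk({0: -1})
    print(x.chunks)              # ((10,),)
    ok = da.map_blocks(np.cumsum, x)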