Skip to content

Commit 2596a40

Browse files
Merge pull request jax-ml#24412 from jakevdp:fromfunction-doc
PiperOrigin-RevId: 688576529
2 parents 587832f + 7e38cbd commit 2596a40

File tree

1 file changed

+78
-1
lines changed

1 file changed

+78
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6215,9 +6215,86 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
62156215
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
62166216
return from_dlpack(x, device=device, copy=copy)
62176217

6218-
@util.implements(np.fromfunction)
6218+
62196219
def fromfunction(function: Callable[..., Array], shape: Any,
62206220
*, dtype: DTypeLike = float, **kwargs) -> Array:
6221+
"""Create an array from a function applied over indices.
6222+
6223+
JAX implementation of :func:`numpy.fromfunction`. The JAX implementation
6224+
differs in that it dispatches via :func:`jax.vmap`, and so unlike in NumPy
6225+
the function logically operates on scalar inputs, and need not explicitly
6226+
handle broadcasted inputs (See *Examples* below).
6227+
6228+
Args:
6229+
function: a function that takes *N* dynamic scalars and outputs a scalar.
6230+
shape: a length-*N* tuple of integers specifying the output shape.
6231+
dtype: optionally specify the dtype of the inputs. Defaults to floating-point.
6232+
kwargs: additional keyword arguments are passed statically to ``function``.
6233+
6234+
Returns:
6235+
An array of shape ``shape`` if ``function`` returns a scalar, or in general
6236+
a pytree of arrays with leading dimensions ``shape``, as determined by the
6237+
output of ``function``.
6238+
6239+
See also:
6240+
- :func:`jax.vmap`: the core transformation that the :func:`fromfunction`
6241+
API is built on.
6242+
6243+
Examples:
6244+
Generate a multiplication table of a given shape:
6245+
6246+
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
6247+
Array([[ 0, 0, 0, 0, 0, 0],
6248+
[ 0, 1, 2, 3, 4, 5],
6249+
[ 0, 2, 4, 6, 8, 10]], dtype=int32)
6250+
6251+
When ``function`` returns a non-scalar the output will have leading
6252+
dimension of ``shape``:
6253+
6254+
>>> def f(x):
6255+
... return (x + 1) * jnp.arange(3)
6256+
>>> jnp.fromfunction(f, shape=(2,))
6257+
Array([[0., 1., 2.],
6258+
[0., 2., 4.]], dtype=float32)
6259+
6260+
``function`` may return multiple results, in which case each is mapped
6261+
independently:
6262+
6263+
>>> def f(x, y):
6264+
... return x + y, x * y
6265+
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
6266+
>>> print(x_plus_y)
6267+
[[0. 1. 2. 3. 4.]
6268+
[1. 2. 3. 4. 5.]
6269+
[2. 3. 4. 5. 6.]]
6270+
>>> print(x_times_y)
6271+
[[0. 0. 0. 0. 0.]
6272+
[0. 1. 2. 3. 4.]
6273+
[0. 2. 4. 6. 8.]]
6274+
6275+
The JAX implementation differs slightly from NumPy's implementation. In
6276+
:func:`numpy.fromfunction`, the function is expected to explicitly operate
6277+
element-wise on the full grid of input values:
6278+
6279+
>>> def f(x, y):
6280+
... print(f"{x.shape = }\\n{y.shape = }")
6281+
... return x + y
6282+
...
6283+
>>> np.fromfunction(f, (2, 3))
6284+
x.shape = (2, 3)
6285+
y.shape = (2, 3)
6286+
array([[0., 1., 2.],
6287+
[1., 2., 3.]])
6288+
6289+
In :func:`jax.numpy.fromfunction`, the function is vectorized via
6290+
:func:`jax.vmap`, and so is expected to operate on scalar values:
6291+
6292+
>>> jnp.fromfunction(f, (2, 3))
6293+
x.shape = ()
6294+
y.shape = ()
6295+
Array([[0., 1., 2.],
6296+
[1., 2., 3.]], dtype=float32)
6297+
"""
62216298
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
62226299
for i in range(len(shape)):
62236300
in_axes = [0 if i == j else None for j in range(len(shape))]

0 commit comments

Comments
 (0)