@@ -6215,9 +6215,86 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
62156215 from jax .dlpack import from_dlpack # pylint: disable=g-import-not-at-top
62166216 return from_dlpack (x , device = device , copy = copy )
62176217
6218- @ util . implements ( np . fromfunction )
6218+
62196219def fromfunction (function : Callable [..., Array ], shape : Any ,
62206220 * , dtype : DTypeLike = float , ** kwargs ) -> Array :
6221+ """Create an array from a function applied over indices.
6222+
6223+ JAX implementation of :func:`numpy.fromfunction`. The JAX implementation
6224+ differs in that it dispatches via :func:`jax.vmap`, and so unlike in NumPy
6225+ the function logically operates on scalar inputs, and need not explicitly
6226+ handle broadcasted inputs (See *Examples* below).
6227+
6228+ Args:
6229+ function: a function that takes *N* dynamic scalars and outputs a scalar.
6230+ shape: a length-*N* tuple of integers specifying the output shape.
6231+ dtype: optionally specify the dtype of the inputs. Defaults to floating-point.
6232+ kwargs: additional keyword arguments are passed statically to ``function``.
6233+
6234+ Returns:
6235+ An array of shape ``shape`` if ``function`` returns a scalar, or in general
6236+ a pytree of arrays with leading dimensions ``shape``, as determined by the
6237+ output of ``function``.
6238+
6239+ See also:
6240+ - :func:`jax.vmap`: the core transformation that the :func:`fromfunction`
6241+ API is built on.
6242+
6243+ Examples:
6244+ Generate a multiplication table of a given shape:
6245+
6246+ >>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
6247+ Array([[ 0, 0, 0, 0, 0, 0],
6248+ [ 0, 1, 2, 3, 4, 5],
6249+ [ 0, 2, 4, 6, 8, 10]], dtype=int32)
6250+
6251+ When ``function`` returns a non-scalar the output will have leading
6252+ dimension of ``shape``:
6253+
6254+ >>> def f(x):
6255+ ... return (x + 1) * jnp.arange(3)
6256+ >>> jnp.fromfunction(f, shape=(2,))
6257+ Array([[0., 1., 2.],
6258+ [0., 2., 4.]], dtype=float32)
6259+
6260+ ``function`` may return multiple results, in which case each is mapped
6261+ independently:
6262+
6263+ >>> def f(x, y):
6264+ ... return x + y, x * y
6265+ >>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
6266+ >>> print(x_plus_y)
6267+ [[0. 1. 2. 3. 4.]
6268+ [1. 2. 3. 4. 5.]
6269+ [2. 3. 4. 5. 6.]]
6270+ >>> print(x_times_y)
6271+ [[0. 0. 0. 0. 0.]
6272+ [0. 1. 2. 3. 4.]
6273+ [0. 2. 4. 6. 8.]]
6274+
6275+ The JAX implementation differs slightly from NumPy's implementation. In
6276+ :func:`numpy.fromfunction`, the function is expected to explicitly operate
6277+ element-wise on the full grid of input values:
6278+
6279+ >>> def f(x, y):
6280+ ... print(f"{x.shape = }\\ n{y.shape = }")
6281+ ... return x + y
6282+ ...
6283+ >>> np.fromfunction(f, (2, 3))
6284+ x.shape = (2, 3)
6285+ y.shape = (2, 3)
6286+ array([[0., 1., 2.],
6287+ [1., 2., 3.]])
6288+
6289+ In :func:`jax.numpy.fromfunction`, the function is vectorized via
6290+ :func:`jax.vmap`, and so is expected to operate on scalar values:
6291+
6292+ >>> jnp.fromfunction(f, (2, 3))
6293+ x.shape = ()
6294+ y.shape = ()
6295+ Array([[0., 1., 2.],
6296+ [1., 2., 3.]], dtype=float32)
6297+ """
62216298 shape = core .canonicalize_shape (shape , context = "shape argument of jnp.fromfunction()" )
62226299 for i in range (len (shape )):
62236300 in_axes = [0 if i == j else None for j in range (len (shape ))]
0 commit comments