Error on jax.numpy.tile when reps is a jax.numpy.array #21740
Not sure if this is a bug or working as intended (WAI). Consider the following two code examples that use `jnp.tile`:
Both of these work with no errors. Now consider the following modification:
This fails with the following error message:
Should `jnp.tile` fail in this instance? I was going by the documentation at https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tile.html. Thank you!
This is expected: the `reps` argument must be a sequence of static integers, because the size of the output array is directly determined by the values in `reps`. Within a JIT context, all JAX arrays are non-static, so you can't pass a JAX array as `reps`.

We need to update the docs here; `reps` is not `array_like`, but rather a sequence of integers. We'll get to that as part of #21461, but haven't done `jax.numpy.tile` yet.
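A minimal sketch of the behaviour described above, assuming a recent JAX release. The function names `tile_traced` and `tile_static` are hypothetical, and the exact exception type raised for a traced `reps` may differ between JAX versions. The usual workaround is to keep `reps` as a hashable tuple of Python ints and mark it static with `static_argnums`:

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.arange(3)

# Outside of jit, a tuple of Python ints is a valid `reps`.
print(jnp.tile(x, (2,)))

# Inside jit, a JAX array argument is a tracer: its values are unknown at
# trace time, so they cannot determine the output shape.
@jax.jit
def tile_traced(x, reps):
    return jnp.tile(x, reps)

try:
    tile_traced(x, jnp.array([2]))
except Exception as e:  # exact error type may vary by JAX version (assumption)
    print("traced reps failed:", type(e).__name__)

# Workaround: mark `reps` as static and pass a hashable tuple of ints.
@partial(jax.jit, static_argnums=1)
def tile_static(x, reps):
    return jnp.tile(x, reps)

print(tile_static(x, (2,)))

# If you start from a concrete (non-traced) array, convert it to a tuple of
# Python ints before calling the jitted function.
reps_arr = jnp.array([2])
print(tile_static(x, tuple(int(r) for r in reps_arr)))
```

Each distinct static `reps` tuple triggers a fresh compilation, which is the price of having the output shape known at trace time.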