All input arrays must have the same shape #13064
Unanswered
PoulomiPradhan asked this question in Q&A
Replies: 1 comment 9 replies
-
JAX does not support ragged arrays. So, for example, this will work, because all entries in the array are the same size:

```python
import jax.numpy as jnp

y = 1.
dist = 2.
pt = jnp.array([0., y, dist])
print(pt)
# [0. 1. 2.]
```

This, however, will raise an error, because you're trying to construct a 2D array with rows of different sizes:

```python
y = jnp.arange(4)
dist = jnp.ones(4)
pt = jnp.array([0., y, dist])
# ValueError: All input arrays must have the same shape.
```

To fix this, you can do a couple of things depending on what you were hoping the output would be. For example, if you wanted a row of zeros in your output array, you could do this:

```python
pt = jnp.array([jnp.zeros_like(y), y, dist])
print(pt)
# [[0. 0. 0. 0.]
#  [0. 1. 2. 3.]
#  [1. 1. 1. 1.]]
```

Hope that helps!
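Since the `y` in the question is a `JVPTracer`, the call is presumably happening inside `jax.grad`. The shape rule is the same for traced values, so the fix is also the same: make every entry the same shape, or concatenate instead of stacking. Below is a minimal sketch of two options, assuming `y` is a 1D array and `dist` a Python float; `stacked`, `flat`, and `loss` are hypothetical helpers, not code from the question. `jnp.broadcast_arrays` expands the scalars to `y`'s shape before stacking (a 2D result, like the `zeros_like` version but with `dist` repeated), while `jnp.hstack` produces one flat vector.

```python
import jax
import jax.numpy as jnp


def stacked(y, dist):
    # Broadcast the scalars 0. and dist to y's shape, then stack them as rows.
    # Result shape: (3,) + y.shape
    return jnp.stack(jnp.broadcast_arrays(0., y, dist))


def flat(y, dist):
    # Concatenate everything into one 1D vector: [0., *y, dist].
    # hstack promotes the scalars to shape (1,) before concatenating.
    return jnp.hstack([0., jnp.ravel(y), dist])


def loss(y, dist, build):
    # Hypothetical scalar loss, only here so that y becomes a JVPTracer.
    pt = build(y, dist)
    return jnp.sum(pt ** 2)


y = jnp.arange(4.)
dist = 2.

print(stacked(y, dist))
# [[0. 0. 0. 0.]
#  [0. 1. 2. 3.]
#  [2. 2. 2. 2.]]
print(flat(y, dist))
# [0. 0. 1. 2. 3. 2.]

# Both versions trace fine under jax.grad; here d/dy sum(pt**2) = 2*y.
print(jax.grad(lambda v: loss(v, dist, stacked))(y))
# [0. 2. 4. 6.]
print(jax.grad(lambda v: loss(v, dist, flat))(y))
# [0. 2. 4. 6.]
```

Which one is right depends on the shape you actually want `pt` to have downstream.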
-
`pt = jnp.array([0., y, dist])` is throwing `ValueError: All input arrays must have the same shape`. `y` is of type `<class 'jax.interpreters.ad.JVPTracer'>` and `dist` is of type `float`.
The line `out = stack([asarray(elt,dtype=dtype)])` in lax_numpy.py (line 1900) is causing this.
How can I overcome this error?
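A minimal sketch reproducing the situation described above, assuming `y` is an argument of a function differentiated with `jax.grad` (the function `f` and helper `try_grad` are hypothetical, not the asker's code). It suggests the `JVPTracer` type itself is not the problem: the same line traces fine when `y` is a scalar and fails only when `y` is an array, because the list entries then have different shapes. The exact exception type and message may depend on the JAX version, so the sketch catches both `TypeError` and `ValueError`.

```python
import jax
import jax.numpy as jnp

dist = 2.0


def f(y):
    # Hypothetical function: builds pt the same way as in the question.
    pt = jnp.array([0., y, dist])
    return jnp.sum(pt)


def try_grad(y):
    # Returns the gradient, or the exception if building pt failed.
    # The exact exception type/message can differ between JAX versions.
    try:
        return jax.grad(f)(y)
    except (TypeError, ValueError) as err:
        return err


# Scalar y: all three entries are 0-d, so they stack into shape (3,).
print(try_grad(1.0))
# 1.0

# y of shape (4,): entries have shapes (), (4,), () and cannot be stacked.
print(repr(try_grad(jnp.arange(4.0))))
```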