static vs traced variables 101 #12553
-
Dear JAX community,

My actual code is quite a bit more complicated, but I've boiled it down to a simple example:

```python
import jax
import jax.numpy as jnp
import numpy as np

def get_shape(data):
    return jnp.shape(data)

@jax.jit
def ones_like(data):
    return jnp.ones(get_shape(data))

ones_like(np.arange(3.))
```

For the life of me, I cannot figure out why this code doesn't work if I jit. I have read up on static vs. traced variables, but I'm not sure how those concepts apply to this example.
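As background for the question, here is a minimal sketch of the general rule (not part of the original thread; the function names are hypothetical). Inside a jitted function, an array's shape is static, meaning plain Python ints known at trace time, while its values are traced. Only static values can be used as array shapes. This assumes a recent JAX release, where the failure surfaces as a `TypeError` (JAX's `ConcretizationTypeError` is a subclass of it).

```python
import jax
import jax.numpy as jnp

@jax.jit
def ones_from_static_shape(data):
    # data.shape is static: a tuple of Python ints fixed when the function is traced.
    n = data.shape[0]
    return jnp.ones(n)  # OK: array shapes must be static

@jax.jit
def ones_from_traced_value(data):
    # data[0] is traced: its value is not known while the function is traced.
    n = data[0]
    return jnp.ones(n)  # fails: a tracer cannot serve as a shape

print(ones_from_static_shape(jnp.arange(3.)).shape)  # (3,)

traced_failed = False
try:
    ones_from_traced_value(jnp.array([3]))
except TypeError as err:  # ConcretizationTypeError is a subclass of TypeError
    traced_failed = True
    print("ones_from_traced_value failed with", type(err).__name__)
```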
-
The raw `get_shape` function takes an array and returns a static shape. If you jit-compile it, though, it returns non-static traced variables. You can see this by adding some print statements:

```python
import jax
import jax.numpy as jnp

def get_shape(data):
    return jnp.shape(data)

@jax.jit
def ones_like(data):
    # These prints run once, while ones_like is being traced.
    print("shape:", get_shape(data))
    print("jit(shape):", jax.jit(get_shape)(data))
    return jnp.ones(get_shape(data))

ones_like(jnp.arange(3.))
```

The output of `shape` is a tuple of static ints, so you can use it to define the shape of a new array. The output of `jit(shape)` is a …

Does that answer your question?
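A self-contained sketch of the diagnosis above, assuming a recent JAX release; `jitted_get_shape`, `ones_like_ok`, and `ones_like_broken` are hypothetical names introduced for illustration. It contrasts calling `get_shape` directly (static ints) with calling `jax.jit(get_shape)` inside another jitted function (traced arrays):

```python
import jax
import jax.numpy as jnp

def get_shape(data):
    # jnp.shape reads the array's static shape: a tuple of Python ints,
    # even when `data` is a tracer inside a jitted function.
    return jnp.shape(data)

# Wrapping the helper in jit turns its outputs into JAX arrays;
# inside another jitted function those arrays are tracers.
jitted_get_shape = jax.jit(get_shape)

@jax.jit
def ones_like_ok(data):
    # Static shape: a valid shape argument for jnp.ones.
    return jnp.ones(get_shape(data))

@jax.jit
def ones_like_broken(data):
    # Traced "shape": jnp.ones cannot build an array from it.
    return jnp.ones(jitted_get_shape(data))

x = jnp.arange(3.)
print(get_shape(x))         # (3,)
print(jitted_get_shape(x))  # a tuple holding a JAX array, not a plain int
print(ones_like_ok(x))

broken_failed = False
try:
    ones_like_broken(x)
except TypeError as err:  # JAX's ConcretizationTypeError is a TypeError subclass
    broken_failed = True
    print("ones_like_broken failed with", type(err).__name__)
```

In this sketch the fix is simply not to jit the shape helper: computing a shape is trace-time Python, so jitting it gains nothing and turns the result into traced values.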