The raw `get_shape` function takes an array and returns a static shape. If you jit-compile it, though, it returns non-static traced values instead. You can see this by adding some print statements:

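```python
import jax
import jax.numpy as jnp

def get_shape(data):
    return jnp.shape(data)

@jax.jit
def ones_like(data):
    # Called directly while tracing: returns a tuple of Python ints.
    print("shape:", get_shape(data))
    # Wrapped in its own jit: returns a tuple of traced integer values.
    print("jit(shape):", jax.jit(get_shape)(data))
    return jnp.ones(get_shape(data))

ones_like(jnp.arange(3.))
```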

Output:

```
shape: (3,)
jit(shape): (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,)
```

The output of `shape` is a tuple of static ints, so you can use it to define the shape of a new array, even inside a jitted function.
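A minimal sketch of the static case, assuming a recent JAX release (`ones_from_static_shape` is a hypothetical name used only for illustration): while the function is being traced, every entry of `get_shape(data)` is a plain Python `int`, which is why `jnp.ones` accepts it under `jit`.

```python
import jax
import jax.numpy as jnp


def get_shape(data):
    return jnp.shape(data)


@jax.jit
def ones_from_static_shape(data):
    # Hypothetical helper. During tracing, the shape is a tuple of
    # concrete Python ints, e.g. (3,), not traced values.
    shape = get_shape(data)
    assert all(isinstance(d, int) for d in shape)
    # Concrete ints are valid array dimensions, even under jit.
    return jnp.ones(shape)


result = ones_from_static_shape(jnp.arange(3.0))
print(result.shape)  # (3,)
```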

The output of jit(shape) is a …
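For contrast, a minimal sketch of the traced case, again assuming a recent JAX release; `scale_by_length` and `ones_from_traced_shape` are hypothetical helpers. Because `jit` converts its outputs to JAX arrays, wrapping `get_shape` in `jax.jit` turns the static ints into 0-d integer arrays, and inside an outer `jit` those arrays are tracers. Tracers work as ordinary values in arithmetic, but not as array dimensions, which must be concrete when the function is traced.

```python
import jax
import jax.numpy as jnp


def get_shape(data):
    return jnp.shape(data)


data = jnp.arange(3.0)

# Even outside any transformation, jit converts the Python-int outputs of
# get_shape into 0-d integer JAX arrays instead of returning plain ints.
eager_out = jax.jit(get_shape)(data)
print(isinstance(eager_out[0], jax.Array), int(eager_out[0]))  # True 3


@jax.jit
def scale_by_length(x):
    # Hypothetical helper: inside an outer jit the inner jit yields tracers,
    # which are fine as ordinary values in arithmetic.
    (n,) = jax.jit(get_shape)(x)
    return x * n


@jax.jit
def ones_from_traced_shape(x):
    # Hypothetical helper: tracers cannot serve as array dimensions,
    # because shapes must be concrete when the function is traced.
    return jnp.ones(jax.jit(get_shape)(x))


print(scale_by_length(data))  # [0. 3. 6.]

raised = False
try:
    ones_from_traced_shape(data)
except TypeError:  # JAX reports non-concrete shape entries as a TypeError
    raised = True
print("TypeError raised:", raised)
```

In practice this means shape bookkeeping belongs in un-jitted helpers (or plain `data.shape`), so the values stay static when the outer function is traced.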

Answer selected by user799595