Regarding the variable being a JVPTracer when you print it within the value_and_grad transform: this is expected behavior. JAX uses tracers as stand-ins for DeviceArray objects when transforming functions with jit, vmap, grad, and other transforms. Take a look at "How to think in JAX" for some background on this.
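To make the tracer behavior concrete, here is a minimal sketch. The function `loss` and the list `seen_types` are made-up names for illustration, not taken from your code, and the exact tracer class name may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

seen_types = []  # hypothetical helper: records the type observed during tracing

def loss(w):
    # Inside value_and_grad, `w` is a tracer standing in for the real array,
    # so printing or inspecting it here shows a tracer type, not array values.
    seen_types.append(type(w).__name__)
    return jnp.sum(w ** 2)

value, grad = jax.value_and_grad(loss)(jnp.array([1.0, 2.0, 3.0]))

print(seen_types[0])  # a tracer class name (JVPTracer in the question)
print(value, grad)    # outside the transform these are ordinary arrays
```

The values returned from the transformed function are concrete again, so if you want to inspect numbers, print the outputs rather than the intermediate variables.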

Regarding the IndexError: it sounds like you're attempting to index a scalar value within the vmap expression, but it's difficult to tell why because your code snippet doesn't show how you're calling the functions. Here's a simpler example of how that can happen:

```python
import jax
import jax.numpy as jnp

@jax.vmap
def f(x):
  print(f"type(x) = {type(x)}")
  print(f"x.shape = {x.shape}")
  r…
```
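Since the snippet above is cut off, here is a separate, self-contained sketch of the same failure mode. It is not the original example: `first_element`, `batch_1d`, and `batch_2d` are hypothetical names, and I'm assuming the error comes from indexing a per-example value that vmap has already reduced to a scalar.

```python
import jax
import jax.numpy as jnp

def first_element(x):
    # Under vmap, `x` is a single slice of the batch. If the batch is 1-D,
    # each slice is a scalar with shape (), so there is nothing left to index.
    return x[0]

batch_1d = jnp.arange(3.0)                # shape (3,)  -> per-example shape ()
batch_2d = jnp.arange(6.0).reshape(3, 2)  # shape (3, 2) -> per-example shape (2,)

try:
    jax.vmap(first_element)(batch_1d)
    error_name = None
except IndexError as err:
    # Indexing a 0-dimensional value raises IndexError during tracing.
    error_name = type(err).__name__

# With a 2-D batch, each slice is 1-D and the same function works.
ok = jax.vmap(first_element)(batch_2d)
print(error_name, ok)
```

If this matches what you're seeing, checking `x.shape` inside the mapped function (as in the snippet above) should show which argument has lost the dimension you expect to index.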

Answer selected by Cai-fx