Advice on debugging vmap/grad interaction #7382
-
I'm posting this here instead of as an issue, since I haven't been able to make a minimal example that reproduces my problem. Hopefully I can get some advice on how to identify the specific issue. The problem I'm having is that the gradient of a vmapped function gives me zeros, even when passing a batch of size 1 to the vmapped function:

```python
# Context: f calls a haiku model and indexes into the output.
# params is a pytree of params; y is a list of arrays.

# Call f and take the mean of the output
def single_element_fn(params, y):
    loss = f(params, y)
    return jnp.mean(loss)

# Same thing, but vmap f over the second argument
def batch_fn(params, y):
    loss = jax.vmap(f, (None, 0))(params, y)
    return jnp.mean(loss)

# Get grad and output for the non-vmapped function
grad_fn = jax.grad(single_element_fn, argnums=1)
fn_output = single_element_fn(params, y)
grad_output = grad_fn(params, y)

# Add a size-1 batch dim to every leaf of y
batch_y = jax.tree_map(lambda arr: arr[None], y)

# Get grad and output for the vmapped function
batch_fn_output = batch_fn(params, batch_y)
batch_grad_fn = jax.grad(batch_fn, argnums=1)
batch_grad_output = batch_grad_fn(params, batch_y)

print('Non-vmap value:', fn_output)
print('Non-vmap grad:', grad_output[0][0, 0, 0, 0])
print('Vmap value:', batch_fn_output)
print('Vmap grad:', batch_grad_output[0][0, 0, 0, 0, 0])

# stdout:
# Non-vmap value: -63.325596
# Non-vmap grad: -0.0038595283
# Vmap value: -63.325596
# Vmap grad: 0.0
```

Is this definitely a bug, or is there something I'm missing?

Edit: I did try this with a very simple function in place of the model I'm using, but in that case grad behaves as expected.
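A self-contained version of that sanity check might look like the sketch below, with a toy `f` standing in for the haiku model (all names and shapes here are illustrative, not the actual code). As noted in the Edit, for a simple `f` the vmapped and non-vmapped gradients agree:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for f (the real f calls a haiku model). Here params is
# a single weight array and y a single input array, not pytrees.
def f(params, y):
    return jnp.sum(params * y ** 2)

def single_element_fn(params, y):
    return jnp.mean(f(params, y))

def batch_fn(params, y):
    return jnp.mean(jax.vmap(f, (None, 0))(params, y))

params = jnp.ones(3)
y = jnp.arange(3.0)
batch_y = y[None]  # add a size-1 batch dim

print(jax.grad(single_element_fn, argnums=1)(params, y))  # [0. 2. 4.]
print(jax.grad(batch_fn, argnums=1)(params, batch_y)[0])  # [0. 2. 4.]
```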
-
Does your model use batch normalization? In my experience batch norm can interact badly with vmap: with a batch size of 1, the normalized output is always 0, and so is its gradient.
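To see why, here is a minimal sketch with a toy batch-norm-style normalization (not the poster's actual model): with a single-element batch, `x - mean(x, axis=0)` is identically zero as a function of `x`, so both the output and its gradient are exactly zero.

```python
import jax
import jax.numpy as jnp

# Toy batch-norm-style normalization over the batch axis.
def batch_norm(x, eps=1e-5):
    mu = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    return (x - mu) / jnp.sqrt(var + eps)

def loss(x):
    return jnp.sum(batch_norm(x))

x = jnp.array([[1.0, 2.0, 3.0]])  # batch size 1
print(loss(x))            # 0.0
print(jax.grad(loss)(x))  # [[0. 0. 0.]]
```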
-
It was very hard to find, but I ended up having a `dynamic_slice_in_dim` which was going out of bounds, leading to the behavior I posted.
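For anyone hitting the same thing: `jax.lax.dynamic_slice_in_dim` does not raise on an out-of-bounds start index; it clamps the start so the slice still fits, silently returning (and routing gradients to) different elements than intended. A minimal sketch of the clamping behavior:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)
# In bounds: returns elements 3 and 4 as expected.
print(jax.lax.dynamic_slice_in_dim(x, 3, 2))  # [3. 4.]
# Out of bounds: start index 4 is silently clamped to 3.
print(jax.lax.dynamic_slice_in_dim(x, 4, 2))  # [3. 4.]
```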