The gradient of an indexing operation with respect to a float index that has been cast to an int behaves like the gradient of `floor`: it is always zero, because an infinitesimal change to the index cannot change which element is selected:

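```python
import jax
import jax.numpy as jnp

def f(x, i):
  # Cast the float index to an integer, then use it to index into x.
  return x[i.astype(int)]

x = jnp.arange(4.0)
i = 1.0
# Differentiate with respect to the index i (argnums=1).
print(jax.grad(f, argnums=1)(x, i))
# 0.0
```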
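To make the analogy with `floor` concrete, here is a small sketch that (1) computes the gradient of `jnp.floor`, (2) checks that nudging the index by a small amount leaves the output unchanged, and (3) shows that the gradient with respect to `x` is still nonzero: it is a one-hot vector at the selected position. It assumes JAX's default 32-bit settings; the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x, i):
  # Same function as above: cast the float index to int, then index.
  return x[i.astype(int)]

x = jnp.arange(4.0)
i = jnp.asarray(1.0)  # an array (not a Python float) so i.astype works outside grad

# floor is piecewise constant, and JAX defines its derivative as zero.
floor_grad = jax.grad(jnp.floor)(1.5)

# The gradient with respect to the index is zero for the same reason.
index_grad = jax.grad(f, argnums=1)(x, i)

# The gradient with respect to x is nonzero: 1 at the selected position.
x_grad = jax.grad(f, argnums=0)(x, i)

# A small nudge to i selects the same element, so the output does not change.
unchanged = bool(f(x, i) == f(x, i + 1e-3))

print(floor_grad)  # 0.0
print(index_grad)  # 0.0
print(x_grad)      # [0. 1. 0. 0.]
print(unchanged)   # True
```

In other words, the output is a step function of `i`, so its derivative is zero almost everywhere, while gradients still flow normally to the values being indexed.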

Answer selected by jakubMitura14