In general, converting concrete traced values to non-traced Python scalars or NumPy arrays removes those values from automatic differentiation, because JAX can no longer see how the output depends on them. We can see this by comparing the output of a simple function with that of the same function when it uses JAX internals to cast its input to a Python float.

Here's a simple auto-diffed function:

import jax

def f(x):
  return 2 * x
print(jax.grad(f)(1.0))
# 2.0

And here's what happens if we cast the traced input to a Python float (which requires using some JAX internals, as it's not recommended):

def f(x):
  # Internal API, not part of JAX's public interface; may differ between versions.
  x_float = float(jax.core.get_aval(x).val)
  return 2.0 * x_float
print(jax.grad(f)(1.0))
# 0.0

The result is incorrect, because we've u…
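
As a rough illustration of the same effect using only public API, the sketch below compares a fully traced function with one that passes its input through `jax.lax.stop_gradient`, which tells JAX to treat the value as a constant during differentiation. This is not the float cast shown above, only a supported way to reproduce a similar zero gradient; the names `f_traced` and `f_stopped` are illustrative and not from the original answer.

```python
import jax

def f_traced(x):
    # x stays a traced value, so autodiff sees the dependence on it.
    return 2.0 * x

def f_stopped(x):
    # stop_gradient makes x behave like a constant for differentiation,
    # similar in effect to converting it to a plain Python float.
    x_const = jax.lax.stop_gradient(x)
    return 2.0 * x_const

grad_traced = jax.grad(f_traced)(1.0)
grad_stopped = jax.grad(f_stopped)(1.0)
print(grad_traced)   # 2.0
print(grad_stopped)  # 0.0
```

Both functions return the same value for a given input; only the gradient differs, which is why this kind of mistake can go unnoticed until the derivatives are checked.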

Answer selected by jacobusmmsmit