For more context, this question extends another discussion that happened here:
I noticed that, unlike other function transformations (e.g. `jax.vmap`, `jax.jit`), `jax.grad` is traced by a concrete tracer, because it has to handle value-based control flow, e.g.:
```python
import jax

def f(x):
    if x < 0:
        return 0.0
    return x ** 2

print(jax.grad(f)(1.0))
```
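For contrast, here is a minimal sketch that runs the same `f` under `jax.grad` and under `jax.jit`. It assumes a reasonably recent JAX release, where converting a traced value to a Python `bool` raises `jax.errors.ConcretizationTypeError` (newer versions raise the subclass `TracerBoolConversionError`). The names `g` and `jit_failed` are just illustrative.

```python
import jax

def f(x):
    # Python-level branch on the *value* of x
    if x < 0:
        return 0.0
    return x ** 2

# Called eagerly, jax.grad sees a concrete value, so `x < 0` can become a bool.
g = jax.grad(f)(1.0)
print(g)

# jax.jit traces f with abstract values (only shape/dtype are known),
# so converting `x < 0` to a Python bool fails at trace time.
jit_failed = False
try:
    jax.jit(f)(1.0)
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)
```

The eager `jax.grad` call succeeds, while the `jax.jit` call fails at the branch, which is the behavioral difference described above.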
Meanwhile, I learned from the above discussion that a jaxpr has to be produced by an abstract tracer. This leads me to wonder:
- Does the `jax.grad` transformation involve a jaxpr?
- If so, how is the jaxpr produced, given that `jax.grad` is traced by a concrete tracer instead of an abstract tracer?
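One way to probe this from the outside (a sketch, not a description of JAX internals) is to compose `jax.grad` with `jax.make_jaxpr`, which traces with abstract values. The helper `h` below is a hypothetical control-flow-free variant of `f`, added only for comparison.

```python
import jax

def f(x):
    if x < 0:
        return 0.0
    return x ** 2

def h(x):
    # Hypothetical helper: same math as f's x >= 0 branch, no value-based branching.
    return x ** 2

# make_jaxpr traces with abstract values; composing it with grad shows
# that the gradient computation of h can itself be staged out as a jaxpr.
closed = jax.make_jaxpr(jax.grad(h))(1.0)
print(closed)

# The same composition fails for f: under abstract tracing the branch
# condition `x < 0` has no concrete value to convert to a bool.
try:
    jax.make_jaxpr(jax.grad(f))(1.0)
    staged_f = True
except jax.errors.ConcretizationTypeError:
    staged_f = False
```

This only shows that `jax.grad` *can* run under an abstract tracer when the function allows it; it does not by itself say whether an eager `jax.grad(f)(1.0)` call builds a jaxpr internally.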