I have a function (i.e. …). I noticed that in a normal call (e.g. …
I wanna understand more about this discrepancy. Why in this case (i.e. …
Two things to know:

1. `ShapedArray` tracers basically contain references to the array's `shape` and `dtype`, while `ConcreteArray` tracers contain references to the array's `shape` and `dtype` as well as its contents.
2. `jax.grad` uses concrete tracers in order to support value-based control flow during tracing. For example:

```python
import jax

def f(x):
  # Python `if` on the *value* of x: this needs a concrete value,
  # not just the shape and dtype, at trace time.
  if x < 0:
    return 0.0
  return x ** 2

print(jax.grad(f)(1.0))
# 2.0
```

If `grad` did not use concrete tracers in this case, the result of the condition `x < 0` would not be knowable at trace time, and we could not differentiate functions that use this kind of control flow.

By contrast, jaxprs are created entirely based on abstract properties of arrays (namely `shape` and