Does jax.value_and_grad imply jit?
#16678
-
From my understanding,
Replies: 2 comments 1 reply
-
Hi. As far as I understand, this is not the case. To be able to calculate the value and gradient of a function, all you need is a 'recipe' for differentiation, which is either inferred from the computational graph or defined with a custom_vjp. I attached an example of a function which cannot be jitted because it uses calls to original numpy, but can be differentiated because a custom_vjp is defined:
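A minimal sketch of that kind of setup (the function np_sin and its VJP rule here are illustrative, not the original attachment):

```python
import numpy as np  # plain NumPy, not jax.numpy
import jax

@jax.custom_vjp
def np_sin(x):
    # Calls plain NumPy, so this body cannot be traced and therefore cannot be jitted.
    return np.sin(x)

def np_sin_fwd(x):
    # Forward pass: compute the primal output and save cos(x) as the residual.
    return np.sin(x), np.cos(x)

def np_sin_bwd(cos_x, g):
    # Backward pass: hand-written VJP rule, since d/dx sin(x) = cos(x).
    return (cos_x * g,)

np_sin.defvjp(np_sin_fwd, np_sin_bwd)

print(jax.grad(np_sin)(0.5))  # works, prints ~0.8776 (= cos(0.5))
# jax.jit(np_sin)(0.5) would fail: plain NumPy cannot convert an abstract tracer to an array
```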
I hope this illustrates the point that you don't need to be able to jit a function to find its gradients. An important point to mention here is that even without jit, each operation is still compiled and executed by XLA, just one at a time. However, if you do use jit (just-in-time compilation), the whole function is compiled together, so you benefit from XLA's whole-program optimizations.
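For the question in the title specifically: if you want the compiled version, you have to ask for it explicitly, e.g. by wrapping value_and_grad in jit. A minimal sketch (the loss function, shapes, and data are made up for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)

# value_and_grad alone: differentiable, but executed op by op.
value, grads = jax.value_and_grad(loss)(w, x, y)

# Wrap it in jit explicitly to compile the whole value-and-gradient
# computation as one XLA program.
fast_vg = jax.jit(jax.value_and_grad(loss))
value, grads = fast_vg(w, x, y)
```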
-
There's a more concise way to see this:

```python
import jax

def relu(x):
    return 0.0 if x < 0 else x

print(jax.grad(relu)(1.0))
# 1.0

print(jax.jit(relu)(1.0))
# ConcretizationTypeError: Abstract tracer value encountered...
```
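Continuing the snippet above, value_and_grad follows the same rule as grad, which answers the original question directly:

```python
print(jax.value_and_grad(relu)(1.0))
# (1.0, 1.0)
```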