lax greater-than function returns zero gradient? #8977
Zero is the correct gradient here. Remember what differentiation is measuring: the change in the output of a function given an infinitesimally small change in its input.

```python
from jax import grad

def f(x):
    return (x > 0).astype(float)

grad(f)(1.0)
# 0.0
```

For `x > 0`, `f(x) = 1`, and an infinitesimally small change in `x` leaves the output unchanged, so the gradient is zero. Similarly, for all values of `x < 0`, `f(x) = 0` and the gradient is zero. That leaves us with the tricky case of `x = 0`. JAX and other autodiff systems tend to handle discontinuities in this way: if the positive gradient and negative gradient disagree, but one is defined and the other is not, we use the one that is defined. At `x = 0` the gradient from below is zero while the gradient from above is undefined, so the defined one is used. Under this definition of the gradient, mathematically and numerically the gradient of this function is everywhere zero.
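To see that convention at the discontinuity itself, a quick check (not part of the original reply) confirms the gradient at `x = 0` is also zero:

```python
from jax import grad

def f(x):
    return (x > 0).astype(float)

# At the discontinuity, the gradient from below (zero) is defined,
# so JAX returns it rather than the undefined gradient from above.
print(grad(f)(0.0))  # 0.0
```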
Hello,
I would like to write a program that draws samples from a source, which is a one-dimensional array of weights. I compare these weights to draws from a uniform distribution, and if a weight is higher than the uniform draw, I keep that sample.
My goal is to fit the drawn sample to a target.
However, the gradient is zero.
Is it possible to make this code differentiable?
Best
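Since the code referenced in the question was not preserved here, the following is a minimal sketch of one common workaround: replacing the hard comparison with a smooth sigmoid surrogate so gradients can flow. The array values, the `temperature` parameter, and the toy `target` are illustrative assumptions, not taken from the question.

```python
import jax
import jax.numpy as jnp

def hard_mask(weights, u):
    # Non-differentiable: the gradient w.r.t. weights is zero everywhere,
    # which is exactly the problem described above.
    return (weights > u).astype(weights.dtype)

def soft_mask(weights, u, temperature=0.1):
    # Differentiable surrogate: sigmoid((weights - u) / temperature)
    # approaches the hard mask as temperature -> 0.
    return jax.nn.sigmoid((weights - u) / temperature)

key = jax.random.PRNGKey(0)
weights = jnp.array([0.2, 0.8, 0.5])
u = jax.random.uniform(key, weights.shape)

def loss(w):
    # Toy objective: mean squared error between the soft mask and a target.
    target = jnp.array([0.0, 1.0, 1.0])
    return jnp.mean((soft_mask(w, u) - target) ** 2)

print(jax.grad(loss)(weights))  # non-zero gradients flow through the sigmoid
```

An alternative that keeps the hard selection in the forward pass is a straight-through estimator, e.g. `soft + jax.lax.stop_gradient(hard - soft)`, which evaluates to the hard mask's value but propagates the sigmoid's gradient.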