Hey, I have a function f(r) that takes a 3-d vector r as input and, internally, calculates the norm of that vector. I want to calculate the n-th derivative of f with respect to the norm |r| for very small norms (|r| = 1e-5), where n can be up to 6. My first attempt was to write
f_with_norm_dependency = lambda unit_vec, norm: f(unit_vec * norm)
and then get the derivative by
f_with_norm_dependency_grad = jax.jacfwd(f_with_norm_dependency, argnums=1)
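For concreteness, a minimal sketch of that setup, assuming the higher orders are taken by stacking jax.jacfwd and using a stand-in body for f:

```python
import jax
import jax.numpy as jnp

# Stand-in for the real f: something that depends on r only through its norm.
def f(r):
    return jnp.exp(-jnp.linalg.norm(r))

f_with_norm_dependency = lambda unit_vec, norm: f(unit_vec * norm)

def nth_norm_derivative(fun, n):
    # Stack forward-mode Jacobians n times with respect to the scalar norm (argnums=1).
    for _ in range(n):
        fun = jax.jacfwd(fun, argnums=1)
    return fun

unit_vec = jnp.array([0.0, 0.0, 1.0])
# At |r| = 1e-5 and n up to 6, this is where the numerical instability shows up.
d6 = nth_norm_derivative(f_with_norm_dependency, 6)(unit_vec, 1e-5)
```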
however, this is numerically unstable for higher-order derivatives. The reason seems to be that differentiating through the norm computation itself is unstable for small norms. Obviously, the first derivative of the norm with respect to the scaling factor is always 1, and all higher-order derivatives should always be 0. Therefore, I tried to write a custom gradient where I check whether the tangent vector is collinear with the primal vector and, if it is, simply return the constant 1.0, like this:
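A minimal sketch of that idea is below; the stable_norm name, the collinearity test and its tolerance, and the fallback branch are assumptions rather than the exact snippet, and f would call stable_norm instead of jnp.linalg.norm internally:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_norm(r):
    return jnp.linalg.norm(r)

@stable_norm.defjvp
def stable_norm_jvp(primals, tangents):
    (r,), (r_dot,) = primals, tangents
    norm = jnp.linalg.norm(r)
    # Check whether the tangent is collinear with the primal (tolerance is a guess).
    cross = jnp.linalg.norm(jnp.cross(r, r_dot))
    is_collinear = cross <= 1e-10 * norm * jnp.linalg.norm(r_dot)
    # Along the radial direction, the derivative of |r| with respect to the scale
    # is exactly 1 for a unit tangent, so return the constant in that case;
    # otherwise fall back to the standard r . r_dot / |r| rule.
    standard = jnp.dot(r, r_dot) / norm
    return norm, jnp.where(is_collinear, 1.0, standard)
```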
This works as long as I only use jax.jacfwd, but it breaks as soon as I try to combine it with jax.jacrev. I need both because the function takes a list of many such r as input, and in another part of the code (a regularization loss) I need to calculate the Laplacian with respect to those r, which means computing first a reverse-mode and then a forward-mode Jacobian of f with respect to r itself, not |r|.
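Assuming the regularization part uses the usual forward-over-reverse composition, that Laplacian looks roughly like the following sketch (whether this matches the actual loss code is an assumption):

```python
import jax
import jax.numpy as jnp

def laplacian(fun):
    # Laplacian of a scalar-valued fun with respect to the full 3-d vector r:
    # trace of the Hessian, built as forward-over-reverse (jacfwd of jacrev).
    def lap(r):
        hessian = jax.jacfwd(jax.jacrev(fun))(r)  # shape (3, 3)
        return jnp.trace(hessian)
    return lap
```

This jacrev-based Hessian is the combination that breaks with the custom jvp above.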
I tried to write a custom vjp for the jvp, but couldn't get anything to work. Do you have an idea how this could be done?