-
I think what you have there still divides by zero and then replaces the NaN with something finite. That usually works for forward-mode AD but can still cause problems in reverse mode. You might need to mask the input before dividing, like:

    mask = matrix == 0
    matrix = jnp.where(mask, 0, 1 / jnp.where(mask, 1, matrix))

The inner `where` prevents any division by zero, and the outer one puts in the correct value. I'm still not sure why it would work for a single iteration but not for two, though.
-
Thanks a lot for your help! I just tried your code and it works perfectly.
-
I am trying to implement the Aberth method to find the roots of a polynomial in JAX. I want a function that is differentiable and that can be jitted on GPU (to bypass #11322). The second requirement works, but not the first.

Below is the code I have written. It can be improved algorithmically (especially the initialization step), but for now I want a minimal working function. The function takes the coefficients of the polynomial and computes all roots at the same time. Starting from some initial guess (here random), it iteratively calls `compute_offset`, which moves the estimates toward the solution. I have commented out the loops (using `for` and `lax.fori_loop`) to isolate the origin of the problem.

To test the gradient, I have written two functions: one parametrizing the coefficients of a polynomial in terms of some parameter `a`, and another computing a scalar from the roots. If there is a single call to `compute_offsets` in the root-finding algorithm, `grad` gives a finite number. If there are two calls, as below, then `grad` gives `nan`.
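Since the original code is not reproduced here, the following is only a hedged sketch of what a differentiable, jittable Aberth iteration could look like. It uses the standard Aberth-Ehrlich update w_k = N_k / (1 - N_k * sum_{j != k} 1/(z_k - z_j)) with N_k = p(z_k)/p'(z_k), and the double-`where` trick from the reply for the pairwise reciprocals. The names `aberth_step`, `aberth_roots`, `coeffs_from_param` and `objective`, the deterministic circle initialization (instead of a random one), the fixed iteration count and the test polynomial (z - a)(z - 2)(z + 1) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def aberth_step(coeffs, z):
    # One Aberth-Ehrlich update of all root estimates z (complex, shape (n,)).
    # coeffs are ordered highest degree first, as jnp.polyval expects.
    newton = jnp.polyval(coeffs, z) / jnp.polyval(jnp.polyder(coeffs), z)
    diff = z[:, None] - z[None, :]          # diff[k, j] = z_k - z_j, zero on the diagonal
    mask = jnp.eye(z.shape[0], dtype=bool)  # static mask of the diagonal
    # Double where: 1/0 is never evaluated, so the backward pass stays finite.
    inv = jnp.where(mask, 0.0, 1.0 / jnp.where(mask, 1.0, diff))
    offset = newton / (1.0 - newton * inv.sum(axis=1))
    return z - offset

def aberth_roots(coeffs, num_iters=100):
    n = coeffs.shape[0] - 1
    # Assumed deterministic start: points on a circle, rotated so that
    # they are not symmetric about the real axis.
    angles = 2.0 * jnp.pi * jnp.arange(n) / n + 0.4
    z0 = 1.5 * jnp.exp(1j * angles)
    # Static (Python int) bounds let fori_loop lower to scan, which supports reverse mode.
    return lax.fori_loop(0, num_iters, lambda i, z: aberth_step(coeffs, z), z0)

def coeffs_from_param(a):
    # Hypothetical test polynomial:
    # (z - a)(z - 2)(z + 1) = z^3 - (1 + a) z^2 + (a - 2) z + 2a
    a = jnp.asarray(a)
    return jnp.stack([jnp.ones_like(a), -(1.0 + a), a - 2.0, 2.0 * a])

def objective(a):
    # Scalar built from the roots: sum of squared real parts = a^2 + 4 + 1.
    roots = aberth_roots(coeffs_from_param(a))
    return jnp.sum(jnp.real(roots) ** 2)

print(jax.jit(aberth_roots)(coeffs_from_param(0.5)))
print(jax.grad(objective)(0.5))  # expected to be close to d/da (a^2 + 5) = 2a = 1.0
```

Differentiating through the unrolled loop works here because, once the estimates have converged, the Jacobian of the update with respect to `z` vanishes, so the gradient reduces to the implicit derivative of each root with respect to `a`.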
I have used `config.update("jax_debug_nans", True)` to find the source of the `nan`: it tells me that the division `1 / matrix` gives `nan` (since `matrix` contains zeros on the diagonal) already in the first call to `compute_offset`. What confuses me is that `grad` still gives a number in this case. I tried to replace the `nan_to_num` with `matrix.at[idx].set(1 / matrix[idx])`, where `idx` contains the non-diagonal indices selected with `argwhere`, but that does not work either (and I still get the `nan` error when NaN debugging is enabled). Any pointers for solving the problem are most welcome.
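Regarding the `argwhere` attempt: under `jit`, `jnp.argwhere` needs a static `size` argument because its output shape depends on the data. Since the diagonal is known from the shape alone, another option is to gather only the off-diagonal differences using indices computed in NumPy at trace time, so the self-difference z_k - z_k is never formed and no masking is needed. The helper `offdiag_reciprocal_sum` below is a hypothetical sketch of that idea, not code from the thread.

```python
import numpy as np
import jax
import jax.numpy as jnp

def offdiag_reciprocal_sum(z):
    # s_k = sum_{j != k} 1 / (z_k - z_j), without ever computing z_k - z_k.
    n = z.shape[0]  # static under jit
    # Row k lists every index except k: j -> j for j < k, j -> j + 1 for j >= k.
    k = np.arange(n)[:, None]
    j = np.arange(n - 1)[None, :]
    others = j + (j >= k)           # NumPy int array, shape (n, n - 1)
    diff = z[:, None] - z[others]   # shape (n, n - 1), no zero self-differences
    return jnp.sum(1.0 / diff, axis=1)

z = jnp.array([1.0, 2.0, 4.0])
print(jax.jit(offdiag_reciprocal_sum)(z))  # approx [-1.3333, 0.5, 0.8333]
print(jax.grad(lambda z: jnp.sum(offdiag_reciprocal_sum(z) ** 2))(z))  # finite
```

This assumes the root estimates stay distinct; if two of them coincide, 1/(z_k - z_j) is infinite with either approach.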