Replies: 3 comments
-
The problem is your use of … So in this case you have … Given that the only meaningful tangent for a non-differentiable type is zero, for any such integer-dtyped outputs I think it should suffice to change your returned tangent to:

```python
return jax.custom_derivatives.SymbolicZero.from_primal_value(x)
```

Or, if you need compatibility across JAX versions, you should be able to use this instead:

```python
aval = jax.core.raise_to_shaped(jax.core.get_aval(x))
if hasattr(aval, "to_tangent_aval"):
    aval = aval.to_tangent_aval()
return jax.custom_derivatives.SymbolicZero(aval)
```
-
@patrick-kidger Thank you very much for the quick response! I think my previous example was misleading: I checked my original function, and it can never return an integer. The only way I could reproduce a similar error message with a minimal function was by adding those type arguments... The other function that is called inside the main function with … I wanted to reply to explain my mistake in the previous example.
-
Closing in favor of #24295. Thanks again for your help, @patrick-kidger!
-
Hello!
I have a function that uses `custom_jvp` with `nondiff_argnums` and `jit` with `static_argnums`. The function takes the derivative order as one of its inputs, and the derivative with respect to the second argument is 0. A simplified version of this function is as follows:

…

This is called inside a much bigger function that I want to take the Jacobian of. Up to JAX version 0.4.33 this was working fine; however, with the 0.4.34 release I started to get the following error:

…

Note: if you just take the Jacobian of `inner`, the error is not raised, so the code above is only meant to show the `custom_jvp` part.

At first I suspected that the second argument (`a`) was causing the problem, since its derivative is set to 0 in a hacky way. But even after I changed it to be a `nondiff_arg`, I still get the same error.

I tried really hard to create a minimal reproducer, but I couldn't recreate the error with a simple case. I am sorry for asking the question this way, but can you think of anything in the new release that might cause this problem?
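To make the setup described above concrete, here is a minimal, hypothetical sketch of the pattern: a `jax.custom_jvp` function with the derivative order in `nondiff_argnums`, called from a function jitted with `static_argnums`, whose JVP rule treats the derivative with respect to the second argument as zero. The names `sin_deriv` and `outer`, the use of `sin` as the function being differentiated, and `a` acting as a mask are all assumptions for illustration, not the original code; the sketch is not claimed to reproduce the 0.4.34 error.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_jvp, nondiff_argnums=(2,))
def sin_deriv(x, a, n):
    # Hypothetical: the n-th derivative of sin(x) is sin(x + n*pi/2).
    # `a` acts as a piecewise-constant mask, so d/da is zero almost everywhere.
    return jnp.where(a > 0, jnp.sin(x + n * jnp.pi / 2), 0.0)


@sin_deriv.defjvp
def _sin_deriv_jvp(n, primals, tangents):
    # With nondiff_argnums, the non-differentiable args come first.
    x, a = primals
    x_dot, _ = tangents  # the tangent of `a` is dropped: its derivative is zero
    y = sin_deriv(x, a, n)
    # d/dx of the n-th derivative is the (n + 1)-th derivative.
    y_dot = sin_deriv(x, a, n + 1) * x_dot
    return y, y_dot


@functools.partial(jax.jit, static_argnums=(2,))
def outer(x, a, n):
    # Hypothetical stand-in for the bigger function that calls sin_deriv.
    return sin_deriv(x, a, n)


x = jnp.array([0.5, 1.0, 2.0])
a = jnp.array([1.0, 0.0, 1.0])
J = jax.jacfwd(outer)(x, a, 1)  # Jacobian w.r.t. x of the 1st derivative
print(J)
```

Because `n` is a plain Python int that is both a non-differentiable argument and a static `jit` argument, each derivative order is traced separately, and the JVP rule can simply call the same function with `n + 1`.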
I believe patrick-kidger/equinox#871 is related to this, but that package uses an even more customized version of `custom_jvp`, so applying the fix to my case was not straightforward :D

Thanks!