Zero class #8112
-
I'm implementing a class
I guess a lot of people have already thought about this. Is there already a solution?
-
Hi, I don't think there's any easy way to do this: JAX doesn't provide any hooks within `jax.numpy` operations to override behavior on custom array-like objects. Regarding `ad_util.Zero`: it serves a different purpose. It is a placeholder for gradients of integers within automatic differentiation, and is not really handled outside the context of autodiff.

One fairly heavyweight way you might approach a problem like this is by implementing a custom JAX transformation that would convert a JAX function into an equivalent function that accepts your custom `Zero` class. This would require defining `Zero`-aware implementations of each of JAX's primitive operations.
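Here is a minimal sketch of that idea. It assumes a hypothetical `Zero` class that stores only a shape and dtype. The sketch follows the custom-interpreter pattern described in the JAX documentation. First, it traces the function to a jaxpr with `jax.make_jaxpr`. Then it walks the equations one at a time. Where a `Zero`-aware rule exists, it calls that rule. Otherwise it materializes any zeros and binds the primitive as usual.

Some caveats on scope:

- `Zero`, `zero_aware`, `eval_zero_aware`, and `ZERO_RULES` are illustrative names, not JAX APIs.
- Only `add`, `sub`, `mul`, and `neg` get rules here.
- Arguments must be positional arrays (no pytrees).
- The code relies on jaxpr internals (`eqn.primitive`, `eqn.params`, `Literal.val`, `get_bind_params`) that are not a stable public API and may change between JAX versions.

```python
import jax
import jax.numpy as jnp


class Zero:
    """Hypothetical symbolic zero: an all-zeros array stored only as shape + dtype."""

    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = jnp.dtype(dtype)


def _materialize(v):
    # Replace a symbolic Zero with a dense zeros array; pass other values through.
    return jnp.zeros(v.shape, v.dtype) if isinstance(v, Zero) else v


def _out_zero(eqn):
    aval = eqn.outvars[0].aval  # shape/dtype of the equation's single output
    return Zero(aval.shape, aval.dtype)


def _broadcast_out(eqn, v):
    # lax binary ops may broadcast a scalar operand, so match the output aval.
    aval = eqn.outvars[0].aval
    return jnp.broadcast_to(jnp.asarray(v, dtype=aval.dtype), aval.shape)


def _default_rule(eqn, *invals):
    # No Zero rule for this primitive: materialize zeros and evaluate it normally
    # (get_bind_params mirrors what JAX's own jaxpr evaluator does before binding).
    subfuns, params = eqn.primitive.get_bind_params(eqn.params)
    outs = eqn.primitive.bind(*subfuns, *map(_materialize, invals), **params)
    return outs if eqn.primitive.multiple_results else [outs]


def _add_rule(eqn, x, y):
    if isinstance(x, Zero) and isinstance(y, Zero):
        return [_out_zero(eqn)]
    if isinstance(x, Zero) or isinstance(y, Zero):
        return [_broadcast_out(eqn, y if isinstance(x, Zero) else x)]
    return _default_rule(eqn, x, y)


def _sub_rule(eqn, x, y):
    if isinstance(y, Zero):
        return [_out_zero(eqn) if isinstance(x, Zero) else _broadcast_out(eqn, x)]
    if isinstance(x, Zero):
        return [_broadcast_out(eqn, jnp.negative(y))]
    return _default_rule(eqn, x, y)


def _mul_rule(eqn, x, y):
    # Treats 0 * v as exactly 0, ignoring IEEE cases such as 0 * inf = nan.
    if isinstance(x, Zero) or isinstance(y, Zero):
        return [_out_zero(eqn)]
    return _default_rule(eqn, x, y)


def _neg_rule(eqn, x):
    return [_out_zero(eqn)] if isinstance(x, Zero) else _default_rule(eqn, x)


ZERO_RULES = {"add": _add_rule, "sub": _sub_rule, "mul": _mul_rule, "neg": _neg_rule}


def eval_zero_aware(jaxpr, consts, *args):
    """Interpret a jaxpr, using a Zero-aware rule for an equation when one exists."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))

    def read(v):
        # Literals (constants inlined into the jaxpr) carry their value in `.val`.
        return v.val if hasattr(v, "val") else env[v]

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        rule = ZERO_RULES.get(eqn.primitive.name, _default_rule)
        env.update(zip(eqn.outvars, rule(eqn, *invals)))
    return [read(v) for v in jaxpr.outvars]


def zero_aware(f):
    """Hypothetical transformation: `f` may receive `Zero` for any positional array arg."""
    def wrapped(*args):
        # Trace with dense stand-ins; tracing only depends on shapes and dtypes.
        closed, out_shape = jax.make_jaxpr(f, return_shape=True)(*map(_materialize, args))
        outs = eval_zero_aware(closed.jaxpr, closed.consts, *args)
        return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(out_shape), outs)
    return wrapped


# Usage: x * y is skipped when x is a Zero, so the result is just y.
f = zero_aware(lambda x, y: x * y + y)
out = f(Zero((3,), jnp.float32), jnp.array([1.0, 2.0, 3.0]))
```

Every primitive without a rule falls back to dense zeros. That includes control-flow primitives such as `cond` or `scan`, whose bodies would have to be interpreted recursively to stay symbolic. Keeping `Zero` symbolic through real programs therefore means writing many more rules, which is what makes this route heavyweight.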