
Hi, I don't think there's any easy way to do this: JAX doesn't provide hooks within jax.numpy operations for overriding behavior on custom array-like objects. Regarding ad_util.Zero: it serves a different purpose. It is a symbolic placeholder for zero-valued gradients (for example, gradients with respect to integer inputs) within automatic differentiation, and it is not really handled outside the context of autodiff.

One fairly heavyweight way to approach a problem like this is to implement a custom JAX transformation that converts a JAX function into an equivalent function that accepts your custom Zero class. This would require defining a Zero-aware implementation of each of JAX's primitive operations.
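A minimal sketch of what such a transformation could look like follows. It assumes a hypothetical `Zero` class that records only a shape and dtype. The sketch traces the function to a jaxpr with `jax.make_jaxpr`, then walks the equations. Where a Zero-aware rule is registered it applies that rule; otherwise it materializes the zeros with `jnp.zeros` and binds the primitive as usual. Only `add`, `mul` and `neg` get rules here. The names `Zero`, `ZERO_RULES`, `zero_aware` and `zero_aware_eval` are illustrative, not a JAX API. The interpreter also touches JAX internals (`Literal`, `primitive.bind`), and their import paths have moved between JAX versions. It does not recurse into sub-jaxprs (e.g. from `jit` or `lax.cond`), so any Zero reaching such a boundary is simply materialized.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # Literal's location differs across JAX versions (internal API).
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


class Zero:
    """Hypothetical symbolic zero: stores shape and dtype, but no data."""

    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = jnp.dtype(dtype)

    def __repr__(self):
        return f"Zero(shape={self.shape}, dtype={self.dtype})"


def materialize(x):
    # Turn a symbolic Zero into a real array; pass anything else through.
    return jnp.zeros(x.shape, x.dtype) if isinstance(x, Zero) else x


def _add_rule(out_aval, x, y):
    if isinstance(x, Zero) and isinstance(y, Zero):
        return Zero(out_aval.shape, out_aval.dtype)
    if isinstance(x, Zero):
        return y
    if isinstance(y, Zero):
        return x
    return None  # no shortcut applies: fall back to the ordinary primitive


def _mul_rule(out_aval, x, y):
    if isinstance(x, Zero) or isinstance(y, Zero):
        return Zero(out_aval.shape, out_aval.dtype)
    return None


def _neg_rule(out_aval, x):
    return Zero(out_aval.shape, out_aval.dtype) if isinstance(x, Zero) else None


# Zero-aware implementations keyed by primitive (only a few shown).
ZERO_RULES = {lax.add_p: _add_rule, lax.mul_p: _mul_rule, lax.neg_p: _neg_rule}


def zero_aware_eval(closed_jaxpr, *args):
    """Evaluate a jaxpr, propagating Zero symbolically where rules exist."""
    env = {}

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    jaxpr = closed_jaxpr.jaxpr
    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = const
    for var, arg in zip(jaxpr.invars, args):
        env[var] = arg

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        rule = ZERO_RULES.get(eqn.primitive)
        outs = None
        if rule is not None and not eqn.primitive.multiple_results:
            out = rule(eqn.outvars[0].aval, *invals)
            outs = None if out is None else [out]
        if outs is None:
            # Fallback: materialize any Zero inputs and run the primitive.
            result = eqn.primitive.bind(*map(materialize, invals), **eqn.params)
            outs = result if eqn.primitive.multiple_results else [result]
        for var, val in zip(eqn.outvars, outs):
            env[var] = val
    return [read(v) for v in jaxpr.outvars]


def zero_aware(f):
    """Hypothetical transformation: make f accept Zero arguments."""

    def wrapped(*args):
        # Trace with ShapeDtypeStructs so Zero inputs never need real data.
        specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in args]
        closed, out_shape = jax.make_jaxpr(f, return_shape=True)(*specs)
        outs = zero_aware_eval(closed, *args)
        treedef = jax.tree_util.tree_structure(out_shape)
        return jax.tree_util.tree_unflatten(treedef, outs)

    return wrapped
```

For example, `zero_aware(lambda a, b: a * b + b)(x, Zero((3,), jnp.float32))` returns a `Zero` without ever allocating its data. A primitive with no registered rule, such as `sin`, still computes correctly on the materialized zeros. Falling back to materialization keeps the sketch correct for every primitive, at the cost of losing the symbolic zero past that point. Adding more entries to `ZERO_RULES` is how the "Zero-aware implementation of each primitive" would be filled in.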

Answer selected by mariogeiger