Here is where the autodiff rules are defined for the real_p and imag_p primitives in JAX: https://github.com/google/jax/blob/292deef6fda9f639fcecd9883a1112825a1eb54f/jax/_src/lax/lax.py#L1890-L1896
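Those rules can be observed through the public API without touching JAX internals. The sketch below assumes that `jnp.real` / `jnp.imag` on a complex input lower to `lax.real` / `lax.imag` (and so bind `real_p` / `imag_p`). The primal and tangent values are arbitrary choices for illustration. It uses `jax.jvp` to show the forward (JVP) rule and `jax.vjp` to show the reverse (transpose) rule:

```python
import jax
import jax.numpy as jnp

z = jnp.complex64(1.0 + 2.0j)   # primal point (arbitrary choice)
dz = jnp.complex64(0.5 - 3.0j)  # tangent direction (arbitrary choice)

# Forward mode: real/imag are linear, so the output tangent is just the
# same map applied to the input tangent: Re(dz) and Im(dz).
real_out, real_tan = jax.jvp(jnp.real, (z,), (dz,))
imag_out, imag_tan = jax.jvp(jnp.imag, (z,), (dz,))

# Reverse mode: the transpose rule turns a real cotangent t into a
# complex cotangent for the complex input.
_, real_pullback = jax.vjp(jnp.real, z)
_, imag_pullback = jax.vjp(jnp.imag, z)
(real_ct,) = real_pullback(jnp.ones_like(real_out))
(imag_ct,) = imag_pullback(jnp.ones_like(imag_out))

print("JVP tangents:", real_tan, imag_tan)
print("VJP cotangents:", real_ct, imag_ct)
```

With a unit cotangent, `real` pulls back to `1+0j` and `imag` pulls back to `-1j`. The minus sign comes from JAX pairing cotangents with tangents without complex conjugation (roughly `Re(ct * dz)`), so recovering `Im(dz)` requires `ct = -1j`.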

They use ad.deflinear2, which is a utility for registering the JVP and transpose rules of linear primitives: https://github.com/google/jax/blob/292deef6fda9f639fcecd9883a1112825a1eb54f/jax/_src/interpreters/ad.py#L517-L519
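Since `ad.deflinear2` is internal, here is a rough user-level analogue built on the public `jax.custom_vjp` instead. Note that this is a different mechanism, not what `deflinear2` does internally. `my_real` and `my_imag` are hypothetical names for illustration. Their backward passes hard-code the transpose of each linear map (`t -> complex(t, 0)` and `t -> complex(0, -t)`). The check compares gradients of `|z|^2` against the built-in `jnp.real` / `jnp.imag` versions:

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-in for real_p (illustration only).
@jax.custom_vjp
def my_real(z):
    return jnp.real(z)

def _my_real_fwd(z):
    return jnp.real(z), None  # linear map: no residuals needed

def _my_real_bwd(_, t):
    # Transpose of Re: real cotangent t -> complex(t, 0).
    return (lax.complex(t, jnp.zeros_like(t)),)

my_real.defvjp(_my_real_fwd, _my_real_bwd)


# Hypothetical stand-in for imag_p (illustration only).
@jax.custom_vjp
def my_imag(z):
    return jnp.imag(z)

def _my_imag_fwd(z):
    return jnp.imag(z), None

def _my_imag_bwd(_, t):
    # Transpose of Im under JAX's unconjugated pairing: t -> complex(0, -t).
    return (lax.complex(jnp.zeros_like(t), -t),)

my_imag.defvjp(_my_imag_fwd, _my_imag_bwd)


def sq_abs_custom(z):
    return my_real(z) ** 2 + my_imag(z) ** 2

def sq_abs_builtin(z):
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2

z = jnp.complex64(3.0 + 4.0j)
g_custom = jax.grad(sq_abs_custom)(z)
g_builtin = jax.grad(sq_abs_builtin)(z)
print(g_custom, g_builtin)
```

Both gradients come out as `6-8j`, i.e. `conj(2z)`, which is JAX's convention for `grad` of a real-valued function of a complex input. A linear primitive only needs its transpose spelled out, because its JVP is simply the primitive applied to the tangent. That is presumably why a single helper such as `deflinear2` can register both rules.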

For more general information regarding autodiff in JAX, some good resources from the docs are

And if you'…

@mmmeee1111
@jakevdp
Answer selected by mmmeee1111