I think the recommended way to do this is to wrap the function so that it returns only the output you want to differentiate, something like:

my_func_jac = jax.jacobian(lambda x, y, z: my_func(x, y, z)[0])

If you jit-compile the function, the XLA compiler will do dead code elimination to ensure that unused results are not computed.
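For a concrete picture, here is a minimal sketch of that pattern. The body of `my_func` is a hypothetical stand-in (the original function is not shown in this thread), and the inputs are made up for illustration. Note that `jax.jacobian` differentiates with respect to the first argument only unless `argnums` says otherwise.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for my_func: it returns two outputs,
# but only the first one is differentiated below.
def my_func(x, y, z):
    first = jnp.sin(x) * y + z
    second = jnp.exp(x) @ jnp.exp(x)  # not used by the Jacobian
    return first, second


# Select the first output inside a lambda, take its Jacobian, then jit.
# By default jax.jacobian differentiates w.r.t. argument 0 (x here).
my_func_jac = jax.jit(jax.jacobian(lambda x, y, z: my_func(x, y, z)[0]))

x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([1.0, 2.0, 3.0])
z = jnp.array([0.5, 0.5, 0.5])

jac = my_func_jac(x, y, z)  # d(first)/dx, shape (3, 3)

# For this particular stand-in the Jacobian is diagonal: cos(x) * y.
expected = jnp.diag(jnp.cos(x) * y)
```

If the other outputs are still needed as plain values, `jax.jacobian` also accepts `has_aux=True`, in which case the wrapped function returns a `(output, aux)` pair and only `output` is differentiated.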

Answer selected by Justin-Tan