There is no easy built-in way to do this, but generalizing `argnums` to accept arbitrary pytrees (rather than only positional argument indices) has been discussed frequently. See #3875, #10614, and the references therein.
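For reference, here is a minimal sketch of what `argnums` does today (assuming a recent JAX; the `loss` function and the `params` dict are made up for illustration). It selects whole positional arguments, so a pytree argument gets a gradient for every one of its leaves, and there's no way to point `argnums` at a single leaf inside it:

```python
import jax
import jax.numpy as jnp

# Illustrative loss: params is a pytree (a dict), x is a plain array.
def loss(params, x):
    return jnp.sum(params["w"] * x + params["b"])

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])

# argnums=0 differentiates with respect to the *entire* first argument.
# The result mirrors the structure of params, with a gradient for each leaf:
#   d/dw = x, and d/db = 2.0 because the scalar b is added to both elements.
g = jax.grad(loss, argnums=0)(params, x)
```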

I think the solutions suggested in the other answers here are probably the best options in the current version of JAX.
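One common pattern of this kind (not necessarily the exact code from the other answers) is to split the pytree into the leaves you want gradients for and the rest, then close over the rest. A minimal sketch, assuming `params` is a flat dict; `grad_wrt_keys` is a hypothetical helper, not a JAX API:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum(params["w"] * x + params["b"])

def grad_wrt_keys(f, keys):
    """Hypothetical helper: gradient of f(params, *args) w.r.t. params[k] for k in keys only."""
    keys = set(keys)

    def wrapped(params, *args):
        # Split the dict into the entries we differentiate and the ones we hold fixed.
        trainable = {k: v for k, v in params.items() if k in keys}
        frozen = {k: v for k, v in params.items() if k not in keys}

        # The frozen entries are captured by the closure, so jax.grad treats them as constants.
        def f_trainable(trainable):
            return f({**frozen, **trainable}, *args)

        return jax.grad(f_trainable)(trainable)

    return wrapped

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])

# Only "w" appears in the result; no gradient is computed for "b".
g = grad_wrt_keys(loss, ["w"])(params, x)
```

For nested pytrees the same idea carries over by splitting and re-merging with `jax.tree_util`, at the cost of a bit more bookkeeping.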
