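The `out_axes` specification should have the same tree structure as the outputs. Strictly, it only needs to be a tree prefix of the outputs (a bare `0` maps every leaf), but to give individual fields different axes, such as `None` for fields that aren't batched, you have to spell out the full structure. So if you're returning a `GraphsTuple`, `out_axes` should also be a `GraphsTuple`, something like this: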

```python
out_axes = jraph.GraphsTuple(nodes=0, edges=None, receivers=None, senders=None,
                             globals=None, n_node=None, n_edge=None)
```
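A minimal runnable sketch of the same idea, assuming a hypothetical two-field `Graph` NamedTuple as a stand-in for `jraph.GraphsTuple` (so it runs without jraph installed). `nodes` depends on the mapped input and is batched along axis 0, while `n_node` is constant across the batch and is left unbatched with `None`:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    # Hypothetical stand-in for jraph.GraphsTuple with only two fields.
    nodes: jnp.ndarray
    n_node: jnp.ndarray


def make_graph(x):
    # Node features depend on the per-example input; the node count does not.
    return Graph(nodes=x * 2.0, n_node=jnp.array([3]))


xs = jnp.ones((4, 3, 5))  # batch of 4 graphs, 3 nodes each, 5 features per node

# out_axes mirrors the output structure: batch `nodes` along axis 0 and
# leave `n_node` unbatched (None), since it is identical for every example.
batched = jax.vmap(make_graph, out_axes=Graph(nodes=0, n_node=None))(xs)
print(batched.nodes.shape)   # (4, 3, 5)
print(batched.n_node.shape)  # (1,)
```

Note that `None` only works for outputs that don't depend on the mapped input; if `n_node` were computed from `x`, `vmap` would raise an error instead.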

Building the axis spec as an instance of the output type can be a problem for custom pytree classes that validate their constructor arguments, since the spec's fields are ints and `None` rather than arrays. I don't know the best way to get around that issue.
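To illustrate the problem, here is a hypothetical custom pytree `Batch` whose `__init__` checks shapes, so `Batch(0, None)` fails. One possible workaround (a sketch, not necessarily the best approach, and only applicable to classes you control) follows the JAX pytree docs' advice for constructors that can't handle unexpected values: have `tree_unflatten` bypass `__init__`, then build the axis spec through `tree_unflatten` instead of the constructor:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Batch:
    """Hypothetical custom pytree whose constructor validates its inputs."""

    def __init__(self, x, mask):
        # Validation that breaks when the fields are ints/None (an axis spec).
        if x.shape[0] != mask.shape[0]:
            raise ValueError("x and mask must have the same leading dimension")
        self.x = x
        self.mask = mask

    def tree_flatten(self):
        return (self.x, self.mask), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__ so the object can be rebuilt from arbitrary leaves
        # (tracers, ints, None, sentinels) without triggering validation.
        obj = object.__new__(cls)
        obj.x, obj.mask = children
        return obj


def f(v):
    # `x` depends on the mapped input; `mask` is the same for every example.
    return Batch(v * 2.0, jnp.ones(3, dtype=bool))


try:
    Batch(0, None)  # direct construction fails: ints have no .shape
except AttributeError as err:
    print("direct construction fails:", err)

# Build the spec via tree_unflatten, which bypasses the validating __init__.
out_axes = Batch.tree_unflatten(None, (0, None))
out = jax.vmap(f, out_axes=out_axes)(jnp.ones((4, 3)))
print(out.x.shape, out.mask.shape)  # (4, 3) (3,)
```

The trade-off is that any object built through `tree_unflatten` is never validated, so the checks in `__init__` only protect user-facing construction.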

Answer selected by juliandwain
Category: Q&A