
Unfortunately, this kind of approach will not work. PyTree flattening must return the same number of leaves regardless of the types of those leaves. Changing the number of leaves dynamically based on their Python type breaks the PyTree flattening contract that transformations like jit rely on: under jit, for instance, the leaves are replaced by tracers, so an isinstance check that passes on your concrete inputs can fail during tracing. See #16170 for a related discussion.
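To see why the leaf count matters: transformations flatten your object, swap the leaves for other values (tracers, under jit), rebuild it with tree_unflatten, and may flatten it again, so that round trip has to reproduce the same structure. Here is a minimal sketch of how it breaks. The class BadPair is hypothetical, and plain object() placeholders stand in for the tracers jit would substitute; like tracers, they are not np.ndarray instances.

```python
import numpy as np
import jax


@jax.tree_util.register_pytree_node_class
class BadPair:
    """Hypothetical pytree showing the anti-pattern: its leaves depend on isinstance checks."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        fields = (self.a, self.b)
        # ANTI-PATTERN: only np.ndarray fields become leaves; anything else is moved
        # into the static aux data, so the leaf count depends on the leaves' types.
        is_leaf = tuple(isinstance(x, np.ndarray) for x in fields)
        children = tuple(x for x, leaf in zip(fields, is_leaf) if leaf)
        static = tuple(None if leaf else x for x, leaf in zip(fields, is_leaf))
        return children, (is_leaf, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        is_leaf, static = aux
        it = iter(children)
        return cls(*(next(it) if leaf else s for leaf, s in zip(is_leaf, static)))


p = BadPair(np.ones(3), np.zeros(3))
leaves, treedef = jax.tree_util.tree_flatten(p)

# Rebuild with non-array placeholders (stand-ins for tracers) and flatten again.
rebuilt = jax.tree_util.tree_unflatten(treedef, [object() for _ in leaves])
leaves2, treedef2 = jax.tree_util.tree_flatten(rebuilt)

print(len(leaves), len(leaves2))  # 2 0
print(treedef == treedef2)        # False
```

The same object goes in with two leaves and comes back with zero, so any transformation that relies on the round trip sees a different tree structure than the one it started with.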

Your best way forward is probably to stop relying on isinstance checks during flattening. One alternative that might work, if done carefully, is to determine at __init__ time what the flattening will look like, store that information for use when tree_flatten is called, and then make sure this __init__ logic is not re-run during unflattening.
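The approach of fixing the layout at __init__ time can be sketched roughly like this. It is a sketch under assumptions, not a drop-in solution: the class Pair and its _is_leaf attribute are hypothetical, "array" is assumed to mean np.ndarray or jax.Array, and tree_unflatten bypasses __init__ with object.__new__ so the isinstance checks never see tracers or placeholder objects.

```python
import numpy as np
import jax


@jax.tree_util.register_pytree_node_class
class Pair:
    """Hypothetical pytree that fixes its flattening layout once, in __init__."""

    def __init__(self, a, b):
        self.a = a
        self.b = b
        # Inspect the real inputs once and record which fields are array leaves.
        self._is_leaf = tuple(isinstance(x, (np.ndarray, jax.Array)) for x in (a, b))

    def tree_flatten(self):
        fields = (self.a, self.b)
        # The layout comes from the stored flags, never from the current leaf types.
        children = tuple(x for x, leaf in zip(fields, self._is_leaf) if leaf)
        static = tuple(None if leaf else x for x, leaf in zip(fields, self._is_leaf))
        return children, (self._is_leaf, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        is_leaf, static = aux
        # Bypass __init__ so its isinstance checks never run on tracers or placeholders.
        obj = object.__new__(cls)
        it = iter(children)
        obj.a, obj.b = (next(it) if leaf else s for leaf, s in zip(is_leaf, static))
        obj._is_leaf = is_leaf
        return obj


# `a` is an array leaf; the float `b` becomes static aux data (so it must be hashable).
p = Pair(np.ones(3), 2.0)
leaves, treedef = jax.tree_util.tree_flatten(p)
rebuilt = jax.tree_util.tree_unflatten(treedef, [object() for _ in leaves])
leaves2, treedef2 = jax.tree_util.tree_flatten(rebuilt)
print(len(leaves), len(leaves2), treedef == treedef2)  # 1 1 True


@jax.jit
def scale(pair):
    return pair.a * pair.b


print(scale(p))
```

The trade-off of this design is that which fields are leaves is fixed when the object is built: a field that starts out as a Python scalar stays static, so under jit each new value of it triggers a recompilation.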

@Impure-King
@jakevdp
Answer selected by Impure-King
Category
Q&A