XlaRuntimeError: PyTree custom to_iterable function should return a pair #16118
Replies: 1 comment
-
Thanks for the question! But IIUC this is a copy of #16119, so maybe we should keep the discussion on that thread. To do that, I'll close this thread.
-
Hi,
After several revisions of my code I still can't get past the same error, which involves a custom pytree. The pytree has four nodes. The first node is what my predict functions are differentiated with respect to; the remaining three are auxiliary, which I have now made explicit. Yet the same error keeps surfacing, pointing at the third node, whenever I try to create an array of pytrees from an array of third nodes (protein_arguments_tuple, a tuple of integers; all of these tuples have the same length). It does work with a single pytree built from a single protein_arguments_tuple as the third node.

I abandoned the array in favour of a list of pytrees built with a for loop, which works, but that only kicks the can down the road: when I try to map my predict function over the list of pytrees, I get exactly the same error. The error appears to come from flatten_customPyTree?
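For reference, here is a minimal sketch of a four-field custom pytree registered with `jax.tree_util.register_pytree_node`, assuming the layout described above: the first field is differentiated and the other three are auxiliary. The class name `CustomPyTree`, the field names `aux_a` and `aux_c`, and the body of `predict` are hypothetical stand-ins, not the original code. The point it illustrates is that the flatten function must return exactly one pair, `(children, aux_data)`, and that aux data should be hashable (a tuple of ints is fine).

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class CustomPyTree:
    """Hypothetical container: `params` is differentiated, the other three fields are auxiliary."""

    def __init__(self, params, aux_a, protein_arguments_tuple, aux_c):
        self.params = params
        self.aux_a = aux_a
        self.protein_arguments_tuple = protein_arguments_tuple
        self.aux_c = aux_c


def flatten_custom_pytree(tree):
    children = (tree.params,)  # leaves that JAX traces and differentiates
    aux_data = (tree.aux_a, tree.protein_arguments_tuple, tree.aux_c)  # static, hashable
    return children, aux_data  # exactly one pair: (children, aux_data)


def unflatten_custom_pytree(aux_data, children):
    (params,) = children
    aux_a, protein_arguments_tuple, aux_c = aux_data
    return CustomPyTree(params, aux_a, protein_arguments_tuple, aux_c)


tree_util.register_pytree_node(CustomPyTree, flatten_custom_pytree, unflatten_custom_pytree)


def predict(tree):
    # Hypothetical predict: sum the weights selected by the integers in the tuple
    ids = jnp.asarray(tree.protein_arguments_tuple)
    return jnp.sum(tree.params[ids])


params = jnp.arange(5.0)
trees = [CustomPyTree(params, "a", t, 0) for t in [(0, 1, 2), (2, 3, 4)]]

# A Python list of registered pytrees is itself a pytree, so it flattens cleanly
leaves, treedef = tree_util.tree_flatten(trees)

# grad returns a CustomPyTree; only grads.params carries derivatives
grads = jax.grad(predict)(trees[0])
```

Because `protein_arguments_tuple` sits in `aux_data` here, JAX treats it as part of the tree structure rather than as data, so two pytrees with different tuples have different treedefs.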
Error:
```
…. File "/opt/anaconda3/lib/python3.8/site-packages/jax/_src/tree_util.py", line 58, in tree_flatten
    return pytree.flatten(tree, is_leaf)
XlaRuntimeError: PyTree custom to_iterable function should return a pair
```
I also tried returning "a pair" in the form of (tuple, None) and (tuple, False), to no avail.
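The "pair" in that message is `(children, aux_data)`: the first element is an iterable of child subtrees and the second is the auxiliary data, so `(children, None)` is already a valid pair when there is no aux data. If the error persists with such a return value, it may be worth checking which object is actually being flattened when it is raised. As a sketch, the same contract in the decorator form, `jax.tree_util.register_pytree_node_class`, looks like this; the class `ProteinInput` and its fields are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ProteinInput:
    """Hypothetical container: `params` is a child, `protein_arguments_tuple` is static aux data."""

    def __init__(self, params, protein_arguments_tuple):
        self.params = params
        self.protein_arguments_tuple = protein_arguments_tuple

    def tree_flatten(self):
        # Return exactly two things: an iterable of children, then the aux data
        return (self.params,), self.protein_arguments_tuple

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)


x = ProteinInput(jnp.ones(3), (0, 1, 2))
flat = x.tree_flatten()

# tree_map only touches the children; aux data is carried through unchanged
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, x)
```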
I can get things to work with slow for loops. The idea is that, with no gradient computed and the variables held constant (a single epoch), the predict function is applied to every data point; hence the list of protein_arguments_tuple, which I couldn't convert into either a JAX array or a NumPy array, nor fold into a single earlier function without creating arrays or lists.
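Since all the tuples have the same length, one possible alternative to a list of pytrees is to stack them into a single 2-D integer array and use `jax.vmap` over its leading axis, holding the parameters fixed with `in_axes=None`. This is a swapped-in approach, not the original code: it assumes the integers are passed as an ordinary (traced) argument rather than stored as static aux data, since `vmap` can only batch over leaves. The `predict` body and the sample data are hypothetical.

```python
import jax
import jax.numpy as jnp


def predict(params, protein_ids):
    # Hypothetical predict: sum the weights selected by one sequence of integers
    return jnp.sum(params[protein_ids])


params = jnp.arange(5.0)  # hypothetical weights, held constant (no grad)
protein_arguments_list = [(0, 1, 2), (2, 3, 4), (4, 4, 0)]

# Equal-length tuples stack into one integer array of shape (num_points, tuple_len)
batch = jnp.asarray(protein_arguments_list)

# Map over axis 0 of `batch`; broadcast `params` to every data point
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, 0)))
preds = batched_predict(params, batch)

# Reference: the slow Python loop, for comparison
loop_preds = jnp.stack([predict(params, jnp.asarray(t)) for t in protein_arguments_list])
```

The design trade-off is that per-point data becomes array data with a batch dimension instead of per-pytree static metadata, which lets a single compiled call replace the Python loop.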
…..
Thanks for looking,
Tom