jax._src.flatten_util - raveling and unraveling. #19355
Can you give more detail on what precisely you're trying to achieve? There may be public APIs that do what you have in mind. I'm having trouble understanding what "ravel a parameter pytree into a single 1D array and rebuild it" means.
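A minimal sketch of the round trip the question seems to describe, using the public jax.flatten_util.ravel_pytree rather than jax._src. The params dict is a hypothetical stand-in for a real parameter pytree. ravel_pytree returns both the flat 1D vector and a callable that rebuilds the original structure, so no separate treedef is needed:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree: a dict of arrays with different shapes.
params = {
    "w": jnp.ones((3, 2)),
    "b": jnp.zeros((2,)),
}

# ravel_pytree returns the flat 1D vector *and* a function that undoes it.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (8,)

# The returned callable already knows the tree structure and leaf shapes,
# so rebuilding the pytree is just a function call on a 1D vector.
rebuilt = unravel(flat)
print(jax.tree_util.tree_map(jnp.shape, rebuilt))
```

Because unravel accepts any 1D vector of the same length and dtype, it can also be applied to a modified vector (for example, after an optimizer step on flat).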
Hi, I am trying to ravel a parameter pytree into a single 1D array and then rebuild it. I found the flatten_util module helpful for the raveling part. But to unravel, we obviously need the tree_def and unravel_list (which I assume are used to split the flat array and reshape the pieces into leaves). How do I get these from the HashablePartial object returned by the ravel_pytree function? Also, is there a clean way to vmap both of these over batches?
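A sketch of one way to batch both directions, assuming every example in the batch has the same tree structure, leaf shapes and dtypes. The HashablePartial returned by ravel_pytree is itself callable and (in the JAX versions I know of) already wraps the treedef and the per-leaf split function, so there should be no need to pull those out; it can be passed straight to jax.vmap. The batched_params dict and batch_size are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

batch_size = 4

# Hypothetical batch of parameter pytrees: every leaf has a leading batch axis.
batched_params = {
    "w": jnp.arange(batch_size * 6, dtype=jnp.float32).reshape(batch_size, 3, 2),
    "b": jnp.ones((batch_size, 2), dtype=jnp.float32),
}

# Build the unravel function once from a single (unbatched) example;
# this assumes all examples share the same structure, shapes and dtypes.
example = jax.tree_util.tree_map(lambda x: x[0], batched_params)
_, unravel = ravel_pytree(example)

# Ravel each example: keep only the flat vector returned by ravel_pytree.
ravel_batch = jax.vmap(lambda p: ravel_pytree(p)[0])
# Unravel each row of a (batch, n) matrix back into a pytree.
unravel_batch = jax.vmap(unravel)

flat_batch = ravel_batch(batched_params)
print(flat_batch.shape)  # (4, 8)

rebuilt = unravel_batch(flat_batch)
print(rebuilt["w"].shape, rebuilt["b"].shape)  # (4, 3, 2) (4, 2)
```

Under vmap each call sees a single example, so the per-example unravel function built from example works unchanged, and the rebuilt leaves get the batch axis back as their leading dimension.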