Converting Pytree params to a vector #9176
Hi, I was wondering how to convert a pytree of params into a single vector with all the parameters flattened, and then map this vector back to the pytree of params.
Thanks for your help.
Answered by PhilipVinc, Jan 12, 2022
Answer selected by dptam
`jax.flatten_util.ravel_pytree` does exactly this. It returns the flattened vector together with a function that unravels (unflattens) a vector of the same size back into the original pytree structure.
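A minimal sketch of the round trip, assuming a recent JAX install. The `params` dictionary is a hypothetical example and does not come from the original thread; only `ravel_pytree` is the function named in the answer.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree: a nested dict of arrays, including a 0-d scalar
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}

# Flatten every leaf into one 1-D vector (6 + 3 + 1 = 10 entries)
# and get back the function that inverts the flattening
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (10,)

# Map a (possibly modified) vector back to the original pytree structure
restored = unravel(flat * 0.5)
print(restored["dense"]["w"].shape)  # (2, 3)
```

Because `unravel` closes over the tree structure and leaf shapes, it can be reused for any vector of the same length and dtype, for example the output of an optimizer that only works on flat arrays.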