Example of applying xmap on serial #8649
-
Looking for an example of applying xmap to the elements of a stax.serial.
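For reference, here is a minimal sketch of mapping over the elements of a stax.serial model. It swaps in `jax.vmap` for `xmap`, because xmap was an experimental API and has since been deprecated. It also assumes the current `jax.example_libraries.stax` import path. The Dense/Relu/Dense network and the "ensemble" of 3 parameter sets are illustrative choices, not part of the original question.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# Illustrative stax.serial network: Dense(4) -> Relu -> Dense(2).
init_fun, apply_fun = stax.serial(stax.Dense(4), stax.Relu, stax.Dense(2))

# Build an "ensemble" of 3 independent parameter sets by mapping init_fun over RNG keys.
# The input shape (-1, 3) is closed over, so only the key is mapped; we keep just the params.
keys = random.split(random.PRNGKey(0), 3)
ensemble_params = jax.vmap(lambda k: init_fun(k, (-1, 3))[1])(keys)

# Map apply_fun over the leading (ensemble) axis of every parameter leaf,
# while sharing the same batch of inputs across all members (in_axes=None).
x = jnp.ones((5, 3))
outputs = jax.vmap(apply_fun, in_axes=(0, None))(ensemble_params, x)

print(outputs.shape)  # (3, 5, 2): ensemble member, batch, output features
```

Here each mapped element is one full parameter set. `in_axes=(0, None)` maps over axis 0 of every leaf in the params list, including the empty tuple contributed by Relu, which has no leaves. The same inputs are broadcast to every member.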
Replies: 3 comments 18 replies
-
This doesn't answer your question, but please note: we discourage the use of …

Can you say more about what you are trying to do? I can't entirely parse the question.
-
In an attempt to be even more specific, and to communicate more clearly, I defined the
-
OK! Now I got it right! 😄
This is the story: under the hood, the JAX `wraps` method (of utils.py) adds an additional element to the PyTreeDef. Once I accounted for that, my code runs without errors (I have yet to determine whether the math is preserved too).

Thanks @skye @avital @hawkinsp @marcvanzee
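When a mapped function fails because of an unexpected tree structure, it can help to print the PyTreeDef of the parameters it actually receives and compare it with the structure you expect. The following is a minimal sketch and does not reproduce the `wraps` behaviour described in the reply. It only shows how to inspect a structure. It assumes the `jax.example_libraries.stax` import path and uses an illustrative Dense/Relu/Dense network. The "expected" structure is a hypothetical one written out by hand.

```python
import jax
from jax import random
from jax.example_libraries import stax

# Illustrative network: Dense(4) -> Relu -> Dense(2).
init_fun, apply_fun = stax.serial(stax.Dense(4), stax.Relu, stax.Dense(2))
_, params = init_fun(random.PRNGKey(0), (-1, 3))

# stax.serial params are a list with one entry per layer;
# parameter-free layers such as Relu contribute an empty tuple.
treedef = jax.tree_util.tree_structure(params)
print(treedef)
print(treedef.num_leaves)  # 4: W and b for each Dense layer

# Hypothetical expected structure, written with placeholder leaves.
expected = jax.tree_util.tree_structure([(0, 0), (), (0, 0)])
if treedef != expected:
    # An extra or missing element shows up as a structure mismatch here.
    print("structure mismatch:", treedef, "vs", expected)
```

Comparing PyTreeDefs directly avoids guessing from error messages. A single extra element anywhere in the tree makes the two structures unequal.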