OK! Now I got it right! 😄
Here is what happened: under the hood, the JAX `wraps` method (in utils.py) adds an additional element to the PyTreeDef. Once I account for that, my code runs without errors (I have yet to determine whether the math is preserved too).
Thanks @skye @avital @hawkinsp @marcvanzee
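
For anyone hitting the same symptom, here is a rough sketch of how an extra element shows up when comparing PyTreeDefs. It does not reproduce the `wraps` behavior itself. The `params` dict and the `wrapped` tuple are hypothetical stand-ins, and it assumes the extra element appears as one additional trailing leaf:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters; the names and shapes are illustrative only.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Assumption: a stand-in for a wrapper that adds one extra element
# around the original tree (here, a trailing scalar in a tuple).
wrapped = (params, jnp.zeros(()))

leaves, treedef = jax.tree_util.tree_flatten(params)
wrapped_leaves, wrapped_treedef = jax.tree_util.tree_flatten(wrapped)

# Printing both structures makes the extra node easy to spot.
print(treedef)
print(wrapped_treedef)

# The two structures no longer compare equal, and the leaf counts differ.
structures_match = treedef == wrapped_treedef
extra_leaves = wrapped_treedef.num_leaves - treedef.num_leaves

# Accounting for the extra element: drop the trailing leaf before
# unflattening with the original treedef.
restored = jax.tree_util.tree_unflatten(treedef, wrapped_leaves[: len(leaves)])
```

Comparing the printed treedefs side by side is usually the quickest way to see where the extra element sits. Whether slicing off a trailing leaf is the right fix depends on where the wrapper actually inserts it.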

Replies: 3 comments 18 replies

skye Dec 8, 2021
Maintainer

skye Dec 10, 2021
Maintainer

Answer selected by mattiasmar
Category
Q&A
5 participants