The two versions are subtly different in their order of operations, and this leads to different temporary variables. It might help to write out the Python intermediate variables explicitly.
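The original `f1` and `f2` from the question are not shown here, so the following is only a hypothetical illustration of the idea. `scale_then_sum` and `sum_then_scale` are made-up names for two functions that return the same values but apply the operations in a different order. Binding each intermediate to a Python variable shows which temporaries each order creates. One caveat: under `jit`, XLA may fuse the multiply into the reduction, so the full-size temporary might never actually be materialized. Whether the difference appears in measured peak memory depends on the compiler.

```python
import jax
import jax.numpy as jnp


def scale_then_sum(x):
    # Scale first: the temporary `scaled` has the full (n, m) shape of x,
    # e.g. another 1000 x 1000 float32 buffer if it is materialized.
    scaled = x * 2.0
    return scaled.sum(axis=1)


def sum_then_scale(x):
    # Reduce first: the temporary `row_sums` has shape (n,) only,
    # so the large intermediate is never needed.
    row_sums = x.sum(axis=1)
    return row_sums * 2.0


x = jnp.ones((1000, 1000), dtype=jnp.float32)

# The jaxprs show the different operation order and intermediate shapes.
print(jax.make_jaxpr(scale_then_sum)(x))
print(jax.make_jaxpr(sum_then_scale)(x))
```

Both functions compute `2 * x.sum(axis=1)`. Only the shape of the intermediate differs: `f32[1000,1000]` in the first jaxpr versus `f32[1000]` in the second.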
Does that make sense?
In the following example, the jaxprs for `f1` and `f2` look the same, but their peak memory usages are different. Could you explain what happened? Furthermore, is it possible to reduce the memory usage to ~4008192 bytes, which corresponds to a float32 array of size (1000, 1000), while keeping the function differentiable and efficient?
Thanks!