Model partitioning when looping jnp.einsum #7609
Unanswered
JohnG-1qbit asked this question in Q&A
Replies: 1 comment
I believe this is an instance of the bug found in issue #7063. It should be fixed in #7206.
Hi,
I have a function that applies `jnp.einsum` to an array in a loop, such that the array is updated after each operation. I'm interested in parallelizing this function to take advantage of model partitioning. My goal is to use model parallelism over multiple CPU devices, and the only function I'm aware of that does this is `xmap` (`pjit` model partitioning is not yet implemented for CPUs). Here is an example of what I'm trying to do:
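(The original code block did not survive extraction here. Purely for illustration, the following is a sketch of the kind of loop being described; every function name, shape, and einsum subscript is an assumption, and the `xmap` wrapper is omitted so the snippet stands on its own.)

```python
import jax.numpy as jnp
from jax import lax

def repeated_einsum(A, y, n_steps):
    """Hypothetical reconstruction: contract A with a carried array y, n_steps times."""
    def _einsum(y, _):
        # Assumed subscripts: A has axes (i, j), y has axes (j, k),
        # and the result has axes (i, k), which is fed back as the carry.
        y = jnp.einsum('ij,jk->ik', A, y)
        return y, None

    y_final, _ = lax.scan(_einsum, y, None, length=n_steps)
    return y_final

A = jnp.eye(4)          # (i, j)
y0 = jnp.ones((4, 3))   # (j, k)
out = repeated_einsum(A, y0, 5)
```

Under `xmap` with named axes, the `'ij,jk->ik'` contraction would produce an output whose axes carry the names `i, k`, which no longer match the `j, k` names expected of the carry on the next iteration, which is exactly the mismatch described below.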
However, one can already tell this will not work, because the named axes of `y` will become `i, k` instead of `j, k` after the first call of `_einsum`, so the next `_einsum` call cannot be executed. More precisely, I get this error:

The code does run when the `xmap` is applied to the interior function `_einsum` rather than the scanned version, but I have seen that this comes with a hefty overhead, and my goal is to take advantage of the speedup that comes from compiling the entire scan together. I have also considered renaming the axes of `y` after each `_einsum` call, but I could not figure out from the documentation how to do that, and it seems contrary to the spirit of `xmap`. I would appreciate any comments or suggestions!