How to localize a distributed jax.DeviceArray? #7845
Let's assume a jax.DeviceArray is distributed over multiple cores, but it is small enough to fit on a single core, so I want to localize it again. How can this be done in JAX? Is there a helper function for it? Or, at least, I would like to localize the data onto a few neighboring cores.
Replies: 3 comments
The following should work. I don't know if it's the optimal way to do this regarding the number of host copies. cc @jekbradbury
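The snippet below is a minimal sketch of one way to do this, not necessarily the code originally posted in this reply. It assumes the array was produced by `jax.pmap` over the local devices and that `jax.device_put` with a single target device is an acceptable way to gather it. `gather_to_device` is a hypothetical helper name used only for illustration, and the sketch makes no claim about how many host copies the transfer involves.

```python
import jax
import jax.numpy as jnp
import numpy as np


def gather_to_device(x, device=None):
    """Copy a (possibly sharded) array onto one device.

    Hypothetical helper for illustration; it only wraps jax.device_put.
    """
    if device is None:
        # Assumption: the first local device is a reasonable default target.
        device = jax.local_devices()[0]
    return jax.device_put(x, device)


n = jax.local_device_count()

# Build an array with one shard per local device (leading axis = device axis).
sharded = jax.pmap(lambda v: v * 2.0)(jnp.ones((n, 4)))

# Gather all shards onto a single device.
local = gather_to_device(sharded)
```

On a machine with a single device the array is already local, so the call is effectively a no-op apart from a possible copy.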
Note that
Thanks a lot!