When I evaluate a `shard_map`ped function in a multi-host environment and get back a replicated result (something identical on all instances), I'm confused about how to use the returned value outside of the `shard_map`ped function. For example, take the psum example from the docs (https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#psum) with a `jax.distributed.initialize()` call added at the beginning. The function `f1` returns the same result on all devices. However, when running with multiple hosts, any operation on `y` (like `y[0]+5` or `y.sum()`) raises a RuntimeError that mentions "Arrays that are not fully addressable". It's not clear to me why the array shouldn't be addressable when it is the same on all instances.
The end of the example states: "Notice also that because f1 returns y_block, the result of a psum over 'i', we can use out_specs=P() so the caller gets a single logical copy of the result value, rather than a tiled result". However, using `out_specs=P()` here produces the same error, so I'm not sure I understand the difference. The error mentions the `with jax.spmd_mode('allow_all'):` context manager, which does work, but wrapping every such operation in it quickly makes more complex code awkward.
I'm not an expert on multi-host environments by any means, so any advice is appreciated. Is there some way to confirm that the separate instances have all completed the calculation and that the replicated value is identical on all of them?
The full RuntimeError:
RuntimeError: Running operations on Arrays that are not fully addressable by this process (i.e. Arrays with data sharded across multiple devices and processes.) is dangerous. It’s very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your jitted computation inside with jax.spmd_mode('allow_all'): context manager.
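A minimal sketch of one possible way to read such a replicated result without `jax.spmd_mode`, offered as an assumption rather than a confirmed answer: since every local shard of a fully replicated array holds the whole value, `jax.Array.addressable_data(0)` should give back a process-local array that ordinary operations accept. The helper name `local_copy` is hypothetical (not part of JAX), `f1` mirrors the psum example from the docs, and the `jax.distributed.initialize()` call is left as a comment so the snippet also runs in a single process. The multi-host behavior has not been verified here.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

# In a real multi-host run, jax.distributed.initialize() would be called here,
# before any other JAX call. It is omitted so this sketch runs in one process.

# One mesh axis "i" spanning every device JAX can see (all hosts, if distributed).
mesh = Mesh(np.array(jax.devices()), ("i",))
n = mesh.devices.size


def f1(x_block):
    # Sum the per-device blocks; every device ends up holding the same value.
    return jax.lax.psum(x_block, "i")


# out_specs=P() declares the output as replicated over the mesh.
f1_sharded = jax.jit(shard_map(f1, mesh=mesh, in_specs=P("i"), out_specs=P()))


def local_copy(arr):
    """Hypothetical helper: process-local copy of a fully replicated jax.Array."""
    if arr.is_fully_addressable:
        return arr  # single-process case: the array is already usable as-is
    if not arr.is_fully_replicated:
        raise ValueError("array is sharded across processes, not replicated")
    # Assumption: each local shard of a replicated array holds the full value.
    return arr.addressable_data(0)


# The same NumPy input is built on every process.
x = np.arange(4 * n, dtype=np.float32)
y = f1_sharded(x)

y_local = local_copy(y)
print(y_local[0] + 5, y_local.sum())
```

The explicit `is_fully_replicated` check is there so the helper fails loudly on a truly sharded array instead of silently returning one shard. For the question of whether all processes hold identical values, `jax.experimental.multihost_utils` contains `assert_equal` and `sync_global_devices`, which may be worth a look; they are not exercised in this sketch.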