Hey all,

I'm prototyping an all-gather kernel using Pallas with `shard_map` on a 2D mesh. When using `make_async_remote_copy` with `DeviceIdType.MESH`, interpret mode raises:

```
NotImplementedError: Meshes with more than 1 named dimension not implemented in dma_start_p
```

The check is in `jax/jax/_src/pallas/mosaic/primitives.py`, lines 619 to 621 in fc1b321:

```python
raise NotImplementedError("Meshes with more than 1 named dimension not "
                          "implemented in dma_start_p")
```

The same code runs correctly on the TPU v5e-4, and I wanted to see if there is a common workaround that doesn't involve flattening the mesh axes.
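For context on why the mesh rank matters here: in a ring all-gather along one axis of a 2D mesh, each step typically targets a neighbor whose `DeviceIdType.MESH` id is a full coordinate tuple in which only the gathered axis changes. The following is a minimal pure-Python sketch of that neighbor arithmetic, assuming a `(2, 2)` mesh with axes `("x", "y")` and a ring along `"y"`; `ring_right_neighbor` is a hypothetical helper, not a Pallas API.

```python
# Hypothetical helper (not part of Pallas): for a ring all-gather along one
# axis of an N-D mesh, compute the mesh coordinates of the right neighbor.
# Only the coordinate along `axis` changes; the others stay fixed.
from typing import Sequence, Tuple


def ring_right_neighbor(
    coords: Sequence[int], mesh_shape: Sequence[int], axis: int
) -> Tuple[int, ...]:
    """Return the coords of the next device along `axis`, wrapping around."""
    if len(coords) != len(mesh_shape):
        raise ValueError("coords and mesh_shape must have the same rank")
    out = list(coords)
    out[axis] = (coords[axis] + 1) % mesh_shape[axis]
    return tuple(out)


# Assumed 2x2 mesh with axes ("x", "y"); gather along "y" (axis index 1).
mesh_shape = (2, 2)
for x in range(mesh_shape[0]):
    for y in range(mesh_shape[1]):
        print((x, y), "->", ring_right_neighbor((x, y), mesh_shape, axis=1))
```

A tuple like `(x, (y + 1) % 2)` is a two-component mesh coordinate, which is the kind of id that the interpret-mode `dma_start_p` check rejects.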
Reproducer:
It looks like the TPU lowering path linearizes N-D mesh coordinates to a scalar logical device ID (`jax/jax/_src/pallas/primitives.py`, lines 1431 to 1478 in fc1b321), and `interpret/utils.py` already contains `device_coords_to_logical_id()` with equivalent logic (`jax/jax/_src/pallas/mosaic/interpret/utils.py`, lines 206 to 220 in fc1b321).
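To make the linearization concrete, here is a minimal sketch of mapping N-D mesh coordinates to a scalar ID and back. It assumes row-major (C) order, with the last mesh axis varying fastest, which is the same convention as `numpy.ravel_multi_index`. It is not copied from JAX's `device_coords_to_logical_id()`. Whether it matches the logical IDs used by the TPU lowering depends on how the mesh orders its devices, so it should be checked against the linked code. Both function names are hypothetical.

```python
# Hypothetical helpers: row-major linearization of mesh coordinates.
# ASSUMPTION: the last mesh axis varies fastest (C order); this is NOT
# necessarily identical to JAX's internal device_coords_to_logical_id().
from typing import Sequence, Tuple


def coords_to_logical_id(coords: Sequence[int], mesh_shape: Sequence[int]) -> int:
    """Flatten N-D mesh coordinates into a single integer ID."""
    if len(coords) != len(mesh_shape):
        raise ValueError("coords and mesh_shape must have the same rank")
    logical_id = 0
    for c, size in zip(coords, mesh_shape):
        if not 0 <= c < size:
            raise ValueError(f"coordinate {c} out of range for axis of size {size}")
        # Horner-style accumulation: id = ((c0 * s1) + c1) * s2 + c2 ...
        logical_id = logical_id * size + c
    return logical_id


def logical_id_to_coords(logical_id: int, mesh_shape: Sequence[int]) -> Tuple[int, ...]:
    """Inverse of coords_to_logical_id under the same row-major assumption."""
    coords = []
    for size in reversed(mesh_shape):
        coords.append(logical_id % size)
        logical_id //= size
    return tuple(reversed(coords))


# Round trip over an assumed 2x2 mesh with axes ("x", "y").
mesh_shape = (2, 2)
for x in range(2):
    for y in range(2):
        lid = coords_to_logical_id((x, y), mesh_shape)
        print((x, y), "->", lid, "->", logical_id_to_coords(lid, mesh_shape))
```

Because the scalar ID carries the same information as the coordinate tuple, applying the same mapping in interpret mode would not require flattening the mesh axes in user code.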
Questions:
- … `NotImplementedError` with linearization logic mirroring the TPU path?