Skip to content

Commit 2cc6522

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add a TPU platform check in make_mesh in the slice_index check because GPUs also have a slice_index attribute (mind-blown)
PiperOrigin-RevId: 833504098
1 parent 94b09dc commit 2cc6522

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/sharding_impls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1191,7 +1191,8 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
11911191
mesh_devices = mesh_utils.create_device_mesh(
11921192
new_axis_shapes, devices,
11931193
allow_split_physical_axes=allow_split_physical_axes)
1194-
if (hasattr(mesh_devices.flat[0], 'slice_index') and
1194+
first_d = mesh_devices.flat[0]
1195+
if (first_d.platform == 'tpu' and hasattr(first_d, 'slice_index') and
11951196
len({d.slice_index for d in mesh_devices.flat}) > 1):
11961197
raise ValueError(
11971198
'`jax.make_mesh` does not support multi-slice topologies. Please use'

0 commit comments

Comments
 (0)