Skip to content

Commit 456dfeb

Browse files
yashk2810Google-ML-Automation
authored andcommitted
[Take 2] Raise a better error message if anything other than a sequence of ints is passed to make_mesh or create_device_mesh
Reverts a158e02 PiperOrigin-RevId: 701045239
1 parent db158e6 commit 456dfeb

File tree

4 files changed

+48
-9
lines changed

4 files changed

+48
-9
lines changed

jax/_src/mesh_utils.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,15 @@ def _transpose_trick(
705705
*_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]
706706
)
707707

708+
def _canonicalize_axis_sizes(axis_sizes: Sequence[int]
709+
) -> tuple[int, ...] | None:
710+
new_sizes = []
711+
for s in axis_sizes:
712+
try:
713+
new_sizes.append(int(s))
714+
except:
715+
return None
716+
return tuple(new_sizes)
708717

709718
def create_device_mesh(
710719
mesh_shape: Sequence[int],
@@ -740,33 +749,41 @@ def create_device_mesh(
740749
"""
741750
if devices is None:
742751
devices = xb.devices()
743-
if np.prod(mesh_shape) != len(devices):
752+
753+
new_mesh_shape = _canonicalize_axis_sizes(mesh_shape)
754+
if new_mesh_shape is None:
755+
raise ValueError(
756+
f'`mesh_shape` passed to `create_device_mesh` should be a sequence of'
757+
f' ints. Got {mesh_shape}')
758+
del mesh_shape
759+
760+
if math.prod(new_mesh_shape) != len(devices):
744761
raise ValueError(
745762
f'Number of devices {len(devices)} must equal the product '
746-
f'of mesh_shape {mesh_shape}'
763+
f'of mesh_shape {new_mesh_shape}'
747764
)
748765
last_device = devices[-1]
749766

750767
handler = device_kind_handler_dict.get(last_device.device_kind, None)
751768
if handler is not None:
752769
result = handler(
753-
mesh_shape, devices, contiguous_submeshes=contiguous_submeshes
770+
new_mesh_shape, devices, contiguous_submeshes=contiguous_submeshes
754771
)
755772
if result is not None:
756773
return result
757774

758775
if last_device.platform == 'tpu':
759776
physical_mesh = _get_physical_tpu_mesh(devices)
760777
if contiguous_submeshes:
761-
physical_mesh = _transpose_trick(physical_mesh, mesh_shape)
778+
physical_mesh = _transpose_trick(physical_mesh, new_mesh_shape)
762779
device_mesh, _ = _create_device_mesh_for_nd_torus(
763780
physical_mesh,
764-
mesh_shape,
781+
new_mesh_shape,
765782
allow_split_physical_axes=allow_split_physical_axes,
766783
)
767784
return device_mesh
768785
else:
769-
device_mesh = np.asarray(devices).reshape(mesh_shape)
786+
device_mesh = np.asarray(devices).reshape(new_mesh_shape)
770787
return device_mesh
771788

772789

jax/_src/sharding_impls.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,17 +1714,25 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
17141714
"""
17151715
if devices is None:
17161716
devices = xla_bridge.devices()
1717-
axis_size = math.prod(axis_shapes)
1717+
new_axis_shapes = mesh_utils._canonicalize_axis_sizes(axis_shapes)
1718+
if new_axis_shapes is None:
1719+
raise ValueError(
1720+
'`axis_shapes` passed to `make_mesh` should be a sequence of ints.'
1721+
f' Got {axis_shapes}')
1722+
del axis_shapes
1723+
1724+
axis_size = math.prod(new_axis_shapes)
17181725
if axis_size > len(devices):
17191726
raise ValueError(
17201727
f'Number of devices {len(devices)} must be >= the product '
1721-
f'of mesh_shape {axis_shapes}')
1728+
f'of mesh_shape {new_axis_shapes}')
17221729
elif axis_size < len(devices):
17231730
devices = devices[:axis_size]
17241731
if devices[0].device_kind in (mesh_utils._TPU_V5_LITE, mesh_utils._TPU_V5E):
17251732
allow_split_physical_axes = True
17261733
else:
17271734
allow_split_physical_axes = False
17281735
mesh_devices = mesh_utils.create_device_mesh(
1729-
axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes)
1736+
new_axis_shapes, devices,
1737+
allow_split_physical_axes=allow_split_physical_axes)
17301738
return mesh_lib.Mesh(mesh_devices, axis_names)

tests/mesh_utils_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,12 @@ def test_create_device_mesh_for_nd_torus(
353353
)
354354
self.assertArraysEqual(assignment, expected_assignment_matrix)
355355

356+
def test_create_device_mesh_non_int_error(self):
357+
with self.assertRaisesRegex(
358+
ValueError,
359+
"`mesh_shape` passed to `create_device_mesh` should be a sequence of ints"):
360+
mesh_utils.create_device_mesh(((4,), 4))
361+
356362
@parameterized.named_parameters(
357363
('2x2x1', mock_2x2x1_devices,),
358364
('2x2x4', mock_2x2x4_devices, ),

tests/pjit_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4458,6 +4458,14 @@ def g(x):
44584458
self.assertEqual(out2.sharding, s)
44594459
self.assertEqual(out2.dtype, np.float32)
44604460

4461+
def test_make_mesh_non_int_error(self):
4462+
with self.assertRaisesRegex(
4463+
ValueError,
4464+
"`axis_shapes` passed to `make_mesh` should be a sequence of ints"):
4465+
jax.make_mesh(((4,), 4), ('x', 'y'))
4466+
4467+
jax.make_mesh((1, np.int32(1), np.int64(1)), ('x', 'y', 'z')) # doesn't crash
4468+
44614469
def test_jnp_array_reshard_error(self):
44624470
if jax.device_count() < 2:
44634471
self.skipTest('Requires >=2 devices')

0 commit comments

Comments
 (0)