Skip to content

Commit a158e02

Browse files
Fabian MentzerGoogle-ML-Automation
authored andcommitted
Reverts cc5036c
PiperOrigin-RevId: 700998046
1 parent 34fe66b commit a158e02

File tree

4 files changed

+1
-21
lines changed

4 files changed

+1
-21
lines changed

jax/_src/mesh_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -705,12 +705,6 @@ def _transpose_trick(
705705
*_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]
706706
)
707707

708-
def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str,
709-
fun_name: str):
710-
if not all(isinstance(s, int) for s in axis_shapes):
711-
raise ValueError(
712-
f'{arg_name} passed to {fun_name} should be a sequence of ints. Got'
713-
f' {axis_shapes}')
714708

715709
def create_device_mesh(
716710
mesh_shape: Sequence[int],
@@ -746,8 +740,7 @@ def create_device_mesh(
746740
"""
747741
if devices is None:
748742
devices = xb.devices()
749-
_validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh')
750-
if math.prod(mesh_shape) != len(devices):
743+
if np.prod(mesh_shape) != len(devices):
751744
raise ValueError(
752745
f'Number of devices {len(devices)} must equal the product '
753746
f'of mesh_shape {mesh_shape}'

jax/_src/sharding_impls.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1714,7 +1714,6 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
17141714
"""
17151715
if devices is None:
17161716
devices = xla_bridge.devices()
1717-
mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh')
17181717
axis_size = math.prod(axis_shapes)
17191718
if axis_size > len(devices):
17201719
raise ValueError(

tests/mesh_utils_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,6 @@ def test_create_device_mesh_for_nd_torus(
353353
)
354354
self.assertArraysEqual(assignment, expected_assignment_matrix)
355355

356-
def test_create_device_mesh_non_int_error(self):
357-
with self.assertRaisesRegex(
358-
ValueError,
359-
"mesh_shape passed to create_device_mesh should be a sequence of ints"):
360-
mesh_utils.create_device_mesh(((4,), 4))
361-
362356
@parameterized.named_parameters(
363357
('2x2x1', mock_2x2x1_devices,),
364358
('2x2x4', mock_2x2x4_devices, ),

tests/pjit_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4458,12 +4458,6 @@ def g(x):
44584458
self.assertEqual(out2.sharding, s)
44594459
self.assertEqual(out2.dtype, np.float32)
44604460

4461-
def test_make_mesh_non_int_error(self):
4462-
with self.assertRaisesRegex(
4463-
ValueError,
4464-
"axis_shapes passed to make_mesh should be a sequence of ints"):
4465-
jax.make_mesh(((4,), 4), ('x', 'y'))
4466-
44674461
def test_jnp_array_reshard_error(self):
44684462
if jax.device_count() < 2:
44694463
self.skipTest('Requires >=2 devices')

0 commit comments

Comments
 (0)