Skip to content

Commit 801fe87

Browse files
bartchr808Google-ML-Automation
authored andcommitted
Do not allow None axis names in meshes.
PiperOrigin-RevId: 686557025
1 parent bb271aa commit 801fe87

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

jax/_src/mesh.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
165165
if isinstance(axis_names, str):
166166
axis_names = (axis_names,)
167167
axis_names = tuple(axis_names)
168+
if not all(i is not None for i in axis_names):
169+
raise ValueError(f"Mesh axis names cannot be None. Got: {axis_names}")
168170

169171
if devices.ndim != len(axis_names):
170172
raise ValueError(

tests/array_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,10 @@ def test_to_device(self):
12951295
self.assertEqual(x_device.device, device)
12961296
self.assertEqual(x_sharding.device, sharding)
12971297

1298+
def test_mesh_with_axis_name_none(self):
1299+
with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
1300+
jax.sharding.Mesh(jax.devices(), (None, 'x'))
1301+
12981302

12991303
@jtu.with_config(jax_use_shardy_partitioner=True)
13001304
class ShardyShardingTest(jtu.JaxTestCase):

tests/pallas/tpu_pallas_distributed_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ def test_kernel(x_ref,
294294
)
295295
)
296296

297-
devices = mesh_utils.create_device_mesh((1, num_devices))
298-
mesh = jax.sharding.Mesh(devices, P(None, 'x'))
297+
devices = mesh_utils.create_device_mesh((num_devices,))
298+
mesh = jax.sharding.Mesh(devices, 'x')
299299
sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
300300
unsharded_arr = jax.random.normal(
301301
jax.random.key(0), shape=(8, 128 * num_devices))

tests/pallas/tpu_pallas_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2165,8 +2165,8 @@ def kernel(x_ref, o_ref, send_sem, recv_sem):
21652165
)
21662166
with self.assertRaisesRegex(
21672167
Exception, 'DMAs with bool dtypes are not supported.'):
2168-
devices = mesh_utils.create_device_mesh((1, num_devices))
2169-
mesh = jax.sharding.Mesh(devices, P(None, 'x'))
2168+
devices = mesh_utils.create_device_mesh((num_devices,))
2169+
mesh = jax.sharding.Mesh(devices, ('x',))
21702170
sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
21712171
input_arr = jax.device_put(input_arr, sharding)
21722172
jax.jit(

0 commit comments

Comments
 (0)