Skip to content

Commit 41f490a

Browse files
yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Default axis_types to Auto for all axis_names if user does not set any AxisType. Also resolve some TODOs now that we have a way for user to set the mesh.
PiperOrigin-RevId: 704944255
1 parent b5e4fd1 commit 41f490a

File tree

4 files changed

+34
-42
lines changed

4 files changed

+34
-42
lines changed

jax/_src/core.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,10 +1626,8 @@ def get_sharding(sharding, ndim):
16261626
return _maybe_modify_sharding(sharding)
16271627

16281628
context_mesh = mesh_lib.get_abstract_mesh()
1629-
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
1630-
# code.
16311629
if not context_mesh:
1632-
return None
1630+
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
16331631
assert sharding is None
16341632
return NamedSharding(context_mesh, P(*[None] * ndim))
16351633

@@ -1692,7 +1690,7 @@ def str_short(self, short_dtypes=False):
16921690
self.dtype.name)
16931691
dt_str = dt_str.replace('void', 'float0')
16941692
if hasattr(self, 'sharding') and self.sharding is not None:
1695-
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
1693+
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
16961694
return f'{dt_str}[{shapestr}]'
16971695
else:
16981696
shapestr = ','.join(map(str, self.shape))
@@ -2658,16 +2656,10 @@ def substitute(aval: AbstractValue):
26582656
return aval
26592657
for v, x in zip(call_jaxpr.invars, in_atoms):
26602658
if not typecompat(substitute(v.aval), x.aval):
2661-
# TODO(yashkatariya): Remove this once numpy array's aval has a sharding
2662-
# on it.
2663-
if (config.sharding_in_types.value and isinstance(x, Literal) and
2664-
v.aval.sharding is not None and x.val.ndim == 0):
2665-
pass
2666-
else:
2667-
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
2668-
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
2669-
f"{x.aval} to jaxpr expecting type "
2670-
f"{substitute(v.aval)}")
2659+
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
2660+
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
2661+
f"{x.aval} to jaxpr expecting type "
2662+
f"{substitute(v.aval)}")
26712663
env[v] = x if type(x) is Var else x.val
26722664

26732665
_check_jaxpr(ctx_factory, call_jaxpr)

jax/_src/mesh.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,6 @@ def __repr__(self):
111111
return self.name
112112

113113
def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
114-
if axis_types is None:
115-
return {}
116114
d = {}
117115
for t, names in axis_types.items():
118116
if isinstance(names, tuple):
@@ -179,7 +177,7 @@ class Mesh(contextlib.ContextDecorator):
179177

180178
devices: np.ndarray
181179
axis_names: tuple[MeshAxisName, ...]
182-
axis_types: MeshAxisType | None
180+
axis_types: MeshAxisType
183181

184182
def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
185183
axis_names: str | Sequence[MeshAxisName], *,
@@ -199,9 +197,9 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
199197
f"devices.ndim == {devices.ndim} and "
200198
f"len(axis_names) == {len(axis_names)}.")
201199

202-
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
203-
axis_types_tuple = (None if axis_types is None else
204-
tuple(axis_types.items()))
200+
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
201+
axis_types)
202+
axis_types_tuple = tuple(axis_types.items())
205203
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
206204
val = _mesh_object_dict.get(key, None)
207205
if val is not None:
@@ -337,7 +335,7 @@ def __str__(self):
337335
def _repr(self):
338336
if self.empty:
339337
return "Mesh(device_ids=[], axis_names=())"
340-
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
338+
atr = f", axis_types={self.axis_types}"
341339
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"
342340

343341
def __repr__(self):
@@ -378,14 +376,13 @@ class AbstractMesh:
378376
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
379377
axis_types: MeshAxisType | None = None):
380378
self.shape_tuple = shape_tuple
381-
self.axis_types = axis_types
382379
if self.shape_tuple:
383380
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
384381
else:
385382
self._axis_names, self._axis_sizes = (), ()
386-
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
387-
self._axis_types_tuple = (None if axis_types is None else
388-
tuple(axis_types.items()))
383+
self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
384+
else axis_types)
385+
self._axis_types_tuple = tuple(self.axis_types.items())
389386

390387
def __hash__(self):
391388
return hash((self.shape_tuple, self._axis_types_tuple))
@@ -399,7 +396,7 @@ def __eq__(self, other):
399396
self._axis_types_tuple == other._axis_types_tuple)
400397

401398
def __repr__(self):
402-
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
399+
atr = f", axis_types={self.axis_types}"
403400
return f"AbstractMesh({self.shape_tuple}{atr})"
404401

405402
@property
@@ -432,26 +429,18 @@ def empty(self):
432429

433430
@functools.cached_property
434431
def _are_all_axes_collective(self) -> bool:
435-
if self.axis_types is None:
436-
return False
437432
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
438433

439434
@functools.cached_property
440435
def _are_all_axes_auto(self) -> bool:
441-
if self.axis_types is None:
442-
return False
443436
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
444437

445438
@functools.cached_property
446439
def _any_axis_collective(self) -> bool:
447-
if self.axis_types is None:
448-
return False
449440
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
450441

451442
@functools.cached_property
452443
def _any_axis_auto(self) -> bool:
453-
if self.axis_types is None:
454-
return False
455444
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
456445

457446
@property
@@ -494,8 +483,6 @@ def _raise_value_error(name):
494483

495484
@contextlib.contextmanager
496485
def set_abstract_mesh(mesh: AbstractMesh):
497-
if mesh is not None and mesh.axis_types is None:
498-
raise RuntimeError('Please set the AxisTypes of Mesh.')
499486
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
500487
try:
501488
yield

jax/_src/pjit.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,6 @@ def get_abstract_mesh_from_avals(in_avals):
698698
return None
699699
m = None
700700
for a in in_avals:
701-
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
702-
if a.sharding is None: # type: ignore
703-
continue
704701
if m is not None and m != a.sharding.mesh:
705702
raise ValueError(
706703
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
@@ -1788,9 +1785,7 @@ def _pjit_lower(
17881785
lowering_parameters: mlir.LoweringParameters,
17891786
pgle_profiler: profiler.PGLEProfiler | None):
17901787
if config.sharding_in_types.value:
1791-
cur_mesh = mesh_lib.get_concrete_mesh()
1792-
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
1793-
api_name = 'jit'
1788+
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
17941789
else:
17951790
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
17961791
if resource_env is not None else (None, 'jit'))

tests/pjit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5483,6 +5483,24 @@ def f(x):
54835483

54845484
self.assertIn('@Sharding', f.lower(arr).as_text())
54855485

5486+
@jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
5487+
def test_only_auto(self, mesh):
5488+
np_inp = np.arange(16.).reshape(8, 2)
5489+
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
5490+
5491+
@jax.jit
5492+
def f(x, x2):
5493+
y = x * 2
5494+
self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, None))
5495+
z = jnp.sin(y)
5496+
self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, None))
5497+
a = z @ x2
5498+
self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
5499+
return a
5500+
5501+
out = f(arr, arr.T)
5502+
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
5503+
54865504
def test_auto_user(self):
54875505
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
54885506
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})

0 commit comments

Comments
 (0)