Commit b5e4fd1

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode, either on its own or mixed with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in the `ShapedArray` constructor, which is the central place where such substitutions can be made. During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would let XLA shard them even further, which is not what we want when some of the dims are user-sharded.

PiperOrigin-RevId: 704911253
1 parent e88b578 commit b5e4fd1
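
To make the intended usage concrete, here is a minimal sketch of a mixed-mode mesh, modeled on the updated test_util helpers in this commit; the internal import path and the availability of at least four devices are assumptions:

import jax
from jax._src import mesh as mesh_lib  # internal module; path is an assumption

# 'x' is a User (explicitly sharded) axis, 'y' is an Auto axis left to XLA.
mesh = jax.make_mesh(
    (2, 2), ('x', 'y'),
    axis_types={mesh_lib.AxisTypes.User: ('x',),
                mesh_lib.AxisTypes.Auto: ('y',)})

# set_mesh now requires AxisTypes to be set on the mesh it is given.
with mesh_lib.set_mesh(mesh):
  # Inside this scope, a spec like P('x', 'y') on an aval is rewritten so the
  # Auto axis 'y' becomes UNCONSTRAINED (see the jax/_src/core.py diff below).
  ...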

10 files changed: +157 -56 lines changed


jax/BUILD

Lines changed: 2 additions & 0 deletions
@@ -455,6 +455,7 @@ pytype_strict_library(
         ":dtypes",
         ":effects",
         ":mesh",
+        ":partition_spec",
         ":pretty_printer",
         ":source_info_util",
         ":traceback_util",
@@ -558,6 +559,7 @@ pytype_strict_library(
         ":layout",
         ":op_shardings",
         ":partial_eval",
+        ":partition_spec",
         ":path",
         ":pickle_util",
         ":sharding",

jax/_src/core.py

Lines changed: 23 additions & 20 deletions
@@ -39,6 +39,7 @@
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
+from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
     TracerIntegerConversionError, UnexpectedTracerError)
@@ -1599,13 +1600,30 @@ def _invalid_shape_error(shape: Shape, context: str=""):

   return TypeError(msg)

+# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
+# Collective too.
+def _maybe_modify_sharding(sharding):
+  if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
+    return sharding
+
+  new_spec = []
+  for s in sharding.spec:
+    if s is None or isinstance(s, UnconstrainedSingleton):
+      new_spec.append(s)
+    else:
+      temp_s = s[0] if isinstance(s, tuple) else s
+      new_spec.append(
+          P.UNCONSTRAINED
+          if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
+  return sharding.with_spec(new_spec)
+

 def get_sharding(sharding, ndim):
-  from jax._src.sharding_impls import NamedSharding, PartitionSpec as P  # type: ignore
+  from jax._src.sharding_impls import NamedSharding  # type: ignore

   if sharding is not None:
     assert len(sharding.spec) == ndim
-    return sharding
+    return _maybe_modify_sharding(sharding)

   context_mesh = mesh_lib.get_abstract_mesh()
   # TODO(yashkatariya): Error out and ask users to set the context mesh in their
@@ -1675,9 +1693,7 @@ def str_short(self, short_dtypes=False):
     dt_str = dt_str.replace('void', 'float0')
     if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
-      axis_types = self.sharding.mesh.axis_types
-      axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
-      return f'{dt_str}[{shapestr}]{axt}'
+      return f'{dt_str}[{shapestr}]'
     else:
       shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
@@ -1689,26 +1705,13 @@ def _len(self, ignored_tracer):
       raise TypeError("len() of unsized object") from err  # same as numpy error


-def _get_axis_type_str(axis_types):
-  from jax._src.mesh import AxisTypes  # type: ignore
-
-  out = []
-  for t, axes in axis_types.items():
-    a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
-    if t == AxisTypes.Collective:
-      out.append(f"C:{a}")
-    elif t == AxisTypes.User:
-      out.append(f"U:{a}")
-    else:
-      assert t == AxisTypes.Auto
-      out.append(f"A:{a}")
-  return f"{{{', '.join(out)}}}"
-
 def _get_shape_sharding_str(shape, spec):
   out = []
   for s1, s2 in zip(shape, spec):
     if s2 is None:
       out.append(f"{s1}")
+    elif isinstance(s2, UnconstrainedSingleton):
+      out.append(f"{s1}")
     elif isinstance(s2, tuple):
       ss = ','.join(s for s in s2)
       out.append(f"{s1}@({ss})")

jax/_src/interpreters/mlir.py

Lines changed: 12 additions & 4 deletions
@@ -50,6 +50,7 @@
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import AUTO
+from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -2524,12 +2525,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   # Don't emit a wsc under full manual mode to avoid increasing HLO size.
   if aval.sharding.mesh._are_all_axes_collective:
     return op
+  if aval.sharding.mesh._are_all_axes_auto:
+    return op
+  # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
+  # `return op` early and avoid bloating HLO size.
   proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
            if sharding_proto is None else sharding_proto)
-  # TODO(yashkatariya): Enable this
-  # unspecified_dims = (set(range(aval.ndim))
-  #                     if aval.sharding.mesh._any_axis_collective else None)
-  return wrap_with_sharding_op(ctx, op, aval, proto)
+  unspecified_dims = None
+  if aval.sharding.mesh._any_axis_collective:
+    unspecified_dims = set(range(aval.ndim))
+  elif aval.sharding.mesh._any_axis_auto:
+    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
+                        if isinstance(s, UnconstrainedSingleton)}
+  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
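
The intent of the unspecified_dims change, restated as a tiny standalone sketch (plain Python, not the actual lowering code): under collective axes every dim is left unspecified, otherwise only the dims whose spec entry was rewritten to UNCONSTRAINED, i.e. the Auto dims, are left for XLA to reshard:

def sketch_unspecified_dims(spec, ndim, any_axis_collective):
  # spec entries are axis names, None, or an unconstrained sentinel
  # (represented here as the string 'UNCONSTRAINED' for simplicity).
  if any_axis_collective:
    return set(range(ndim))
  return {i for i, s in enumerate(spec) if s == 'UNCONSTRAINED'}

print(sketch_unspecified_dims(('x', 'UNCONSTRAINED'), 2, False))  # {1}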

jax/_src/interpreters/pxla.py

Lines changed: 5 additions & 3 deletions
@@ -63,7 +63,7 @@
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.partition_spec import PartitionSpec
+from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
@@ -2123,11 +2123,13 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
   @lru_cache(maxsize=128)
   def _abstract_to_concrete_mesh(abstract_mesh):
     return mesh_lib.Mesh(
-        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
+        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
+        axis_types=abstract_mesh.axis_types)

   out = []
   for s, a in zip(shardings, avals):
-    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
+    if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
+        all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:

jax/_src/mesh.py

Lines changed: 29 additions & 7 deletions
@@ -124,6 +124,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:

 _mesh_object_dict = {}  # type: ignore

+MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]

 class Mesh(contextlib.ContextDecorator):
   """Declare the hardware resources available in the scope of this manager.
@@ -178,11 +179,11 @@ class Mesh(contextlib.ContextDecorator):

   devices: np.ndarray
   axis_names: tuple[MeshAxisName, ...]
-  axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
+  axis_types: MeshAxisType | None

   def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
-              axis_names: str | Sequence[MeshAxisName],
-              axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
+              axis_names: str | Sequence[MeshAxisName], *,
+              axis_types: MeshAxisType | None = None):
     if not isinstance(devices, np.ndarray):
       devices = np.array(devices)
     if isinstance(axis_names, str):
@@ -216,7 +217,8 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
     return self

   def __reduce__(self):
-    return (type(self), (self.devices, self.axis_names, self.axis_types))
+    return (type(self), (self.devices, self.axis_names),
+            {'axis_types': self.axis_types})

   def __eq__(self, other):
     if not isinstance(other, Mesh):
@@ -348,7 +350,7 @@ def local_devices(self):

   @functools.cached_property
   def abstract_mesh(self):
-    return AbstractMesh(self.shape_tuple, self.axis_types)
+    return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)


 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -373,8 +375,8 @@ class AbstractMesh:
   details.
   """

-  def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
-               axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
+  def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
+               axis_types: MeshAxisType | None = None):
     self.shape_tuple = shape_tuple
     self.axis_types = axis_types
     if self.shape_tuple:
@@ -434,6 +436,24 @@ def _are_all_axes_collective(self) -> bool:
       return False
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())

+  @functools.cached_property
+  def _are_all_axes_auto(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_collective(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return any(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_auto(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+
   @property
   def devices(self):
     _raise_value_error("devices")
@@ -474,6 +494,8 @@ def _raise_value_error(name):

 @contextlib.contextmanager
 def set_abstract_mesh(mesh: AbstractMesh):
+  if mesh is not None and mesh.axis_types is None:
+    raise RuntimeError('Please set the AxisTypes of Mesh.')
   prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
   try:
     yield
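
Since axis_types is now keyword-only on both Mesh and AbstractMesh, construction looks like the sketch below; it assumes at least four local devices and uses the internal AxisTypes enum, as the updated test_util helpers do:

import numpy as np
import jax
from jax._src.mesh import AxisTypes  # internal import; path is an assumption

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = jax.sharding.Mesh(
    devices, ('x', 'y'),
    axis_types={AxisTypes.User: ('x',), AxisTypes.Auto: ('y',)})

# The abstract_mesh property now forwards axis_types as a keyword argument.
abstract = mesh.abstract_mesh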

jax/_src/partition_spec.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@

 from __future__ import annotations

-class _UnconstrainedPartitionSingleton:
+class UnconstrainedSingleton:

   def __repr__(self):
     return "UNCONSTRAINED"
@@ -23,7 +23,7 @@ def __repr__(self):
 # Unconstrained sentinel value for PartitionSpec, representing a dimension for
 # which the user wants XLA to assign the best partitioning.
 # TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()


 class PartitionSpec(tuple):

jax/_src/sharding_impls.py

Lines changed: 17 additions & 2 deletions
@@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
         f"is also found in manual_axes: {_manual_axes}.") from None


+@util.cache(max_size=128, trace_context_in_key=False)
+def _check_axis_type_consistency(mesh, parsed_pspec):
+  if mesh.axis_types is None:
+    return
+  for p in parsed_pspec:
+    if p is not None:
+      if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
+        raise ValueError(
+            'AxisTypes should be the same in a tuple subset of PartitionSpec:'
+            f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
+            f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
+
+
 def hashed_index(x) -> int:
   # This works for both `pjit` indices and `pmap` indices (which might
   # have an integer instead of a slice).
@@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
       PartitionSpec() if spec is None else spec,
       "NamedSharding spec", allow_unconstrained_dims=True)
   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
+  _check_axis_type_consistency(mesh, parsed_pspec)
   return parsed_pspec


@@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(


 def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
-              *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
+              *, devices: Sequence[xc.Device] | None = None,
+              axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
   """Creates an efficient mesh with the shape and axis names specified.

   This function attempts to automatically compute a good mapping from a set of
@@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
   mesh_devices = mesh_utils.create_device_mesh(
       new_axis_shapes, devices,
       allow_split_physical_axes=allow_split_physical_axes)
-  return mesh_lib.Mesh(mesh_devices, axis_names)
+  return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
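
To illustrate the new consistency check: a spec that groups axes of different types into one tuple entry, e.g. P(('x', 'y')) where 'x' is User and 'y' is Auto, is rejected. A self-contained sketch of the rule (not the cached JAX helper; the name-to-type mapping here is assumed):

def sketch_check_axis_types(spec, name_to_type):
  # Every axis name inside one PartitionSpec entry must have the same AxisType.
  for entry in spec:
    if entry is None:
      continue
    names = entry if isinstance(entry, tuple) else (entry,)
    if len({name_to_type[n] for n in names}) > 1:
      raise ValueError(
          'AxisTypes should be the same in a tuple subset of PartitionSpec: '
          f'got {names}')

sketch_check_axis_types([('x', 'y')], {'x': 'User', 'y': 'Auto'})  # raises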

jax/_src/test_util.py

Lines changed: 7 additions & 5 deletions
@@ -1443,26 +1443,28 @@ def with_and_without_mesh(f):
     ('Mesh', (('x', 2),), (('i', 'x'),))
   ))(with_mesh_from_kwargs(f))

-def with_user_mesh(sizes, names):
+def with_user_mesh(sizes, names, axis_types=None):
+  axis_types = ({mesh_lib.AxisTypes.User: names}
+                if axis_types is None else axis_types)
   def decorator(fn):
     def mesh_fn(*args, **kwargs):
-      mesh = create_mesh(sizes, names)
+      mesh = create_mesh(sizes, names, axis_types=axis_types)
       with mesh_lib.set_mesh(mesh):
         return fn(*args, **kwargs, mesh=mesh)
     return mesh_fn
   return decorator


-def create_mesh(mesh_shape, axis_names, iota_order=False):
+def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
   size = math.prod(mesh_shape)
   if len(jax.devices()) < size:
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   if iota_order:
     devices = sorted(jax.devices(), key=lambda d: d.id)
     mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-    return jax.sharding.Mesh(mesh_devices, axis_names)
+    return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types)
   else:
-    return jax.make_mesh(mesh_shape, axis_names)
+    return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types)

 class _cached_property:
   null = object()
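
A hypothetical test using the updated helper, passing a mixed axis_types dict; the decorator builds the mesh, enters set_mesh, and hands the mesh to the test as the `mesh` keyword argument:

from jax._src import mesh as mesh_lib
from jax._src import test_util as jtu

@jtu.with_user_mesh((2, 2), ('x', 'y'),
                    axis_types={mesh_lib.AxisTypes.User: ('x',),
                                mesh_lib.AxisTypes.Auto: ('y',)})
def test_mixed_mode(mesh):
  ...  # 'x' is a User axis, 'y' is an Auto axis inside this test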

jax/experimental/shard_map.py

Lines changed: 4 additions & 5 deletions
@@ -46,7 +46,8 @@
 from jax._src import traceback_util
 from jax._src import util
 from jax._src.core import Tracer
-from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
+from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
+                           get_abstract_mesh)
 from jax._src.api import _shared_code_pmap, _prepare_pmap
 from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           windowed_reductions, convolution, fft, linalg,
@@ -536,7 +537,7 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
                     for i, sz in enumerate(aval.shape))
   if config.sharding_in_types.value:
     new_mesh = AbstractMesh(
-        mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names})
+        mesh.shape_tuple, axis_types={AxisTypes.Collective: mesh.axis_names})
     new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
   else:
     new_sharding = None
@@ -548,11 +549,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
   assert isinstance(aval, core.ShapedArray)
   new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                     for i, sz in enumerate(aval.shape))
-  # TODO(yashkatariya): Reset the mesh properly based on the input avals if the
-  # mesh of shard_map specifies collective axes.
   if config.sharding_in_types.value:
     spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
-    new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
+    new_sharding = NamedSharding(get_abstract_mesh(), spec)
   else:
     new_sharding = None
   return aval.update(shape=new_shape, sharding=new_sharding)
