
Commit 05716b5

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Support shard_map with sharding in types. Right now only full manual mode is supported.

This change also adds AxisTypes to Mesh, which are `User`, `Auto` and `Collective`. In following changes, I'll remove the `config.sharding_in_types` flag and we'll enter the various modes via the AxisTypes mentioned on the mesh.

PiperOrigin-RevId: 696559375
1 parent a8464ce commit 05716b5
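
Below is a minimal sketch of the new surface this commit adds, assuming four available devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4) and the internal jax._src.mesh import path; these are internal, unstable APIs.

import numpy as np
import jax
from jax._src.mesh import Mesh, AxisTypes

# axis_types is the new, optional third argument to Mesh; it tags each
# mesh axis with one of AxisTypes.{Auto, User, Collective}.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'), axis_types={AxisTypes.User: ('x', 'y')})
print(repr(mesh))
# -> Mesh(device_ids=..., axis_names=('x', 'y'), axis_types={...})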

File tree

7 files changed: +134 −42 lines

  jax/_src/core.py
  jax/_src/lax/lax.py
  jax/_src/mesh.py
  jax/_src/partition_spec.py
  jax/_src/sharding_impls.py
  jax/experimental/shard_map.py
  tests/pjit_test.py

jax/_src/core.py

Lines changed: 24 additions & 5 deletions
@@ -1656,8 +1656,10 @@ def str_short(self, short_dtypes=False):
               self.dtype.name)
     dt_str = dt_str.replace('void', 'float0')
     if hasattr(self, 'sharding') and self.sharding is not None:
-      shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
-      return f'{dt_str}[{shapestr}]'
+      shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
+      axis_types = self.sharding.mesh.axis_types
+      axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
+      return f'{dt_str}[{shapestr}]{axt}'
     else:
       shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
@@ -1669,15 +1671,32 @@ def _len(self, ignored_tracer):
       raise TypeError("len() of unsized object") from err  # same as numpy error


+def _get_axis_type_str(axis_types):
+  from jax._src.mesh import AxisTypes  # type: ignore
+
+  out = []
+  for t, axes in axis_types.items():
+    a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
+    if t == AxisTypes.Collective:
+      out.append(f"C:{a}")
+    elif t == AxisTypes.User:
+      out.append(f"U:{a}")
+    else:
+      assert t == AxisTypes.Auto
+      out.append(f"A:{a}")
+  return f"{{{', '.join(out)}}}"
+
 def _get_shape_sharding_str(shape, spec):
+  out = []
   for s1, s2 in zip(shape, spec):
     if s2 is None:
-      yield f"{s1}"
+      out.append(f"{s1}")
     elif isinstance(s2, tuple):
       ss = ','.join(s for s in s2)
-      yield f"{s1}@({ss})"
+      out.append(f"{s1}@({ss})")
     else:
-      yield f"{s1}@{s2}"
+      out.append(f"{s1}@{s2}")
+  return ','.join(out)

 def _get_abstract_sharding(val):
   from jax._src.sharding_impls import NamedSharding  # pytype: disable=import-error
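
The effect on aval printing, reproduced as a standalone sketch of the helpers above (illustrative only; the real functions live in jax._src.core):

# Mirrors _get_shape_sharding_str: each dimension renders as size@axes.
def shape_sharding_str(shape, spec):
  out = []
  for size, axes in zip(shape, spec):
    if axes is None:
      out.append(f"{size}")
    elif isinstance(axes, tuple):
      out.append(f"{size}@({','.join(axes)})")
    else:
      out.append(f"{size}@{axes}")
  return ','.join(out)

print(shape_sharding_str((8, 2), ('x', 'y')))          # 8@x,2@y
print(shape_sharding_str((8, 2), (('x', 'y'), None)))  # 8@(x,y),2

# With axis types present, str_short appends a suffix built by
# _get_axis_type_str, e.g. a fully-collective f32 aval prints as
# f32[8@x,2@y]{C:(x,y)}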

jax/_src/lax/lax.py

Lines changed: 6 additions & 8 deletions
@@ -2203,14 +2203,13 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
   for op, in_aval in zip(ops, in_avals):
     if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
       out.append(op)
+    elif in_aval.sharding.mesh.are_all_axes_collective:
+      out.append(op)
     else:
       # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains
       # CompilerShardingAxis, then specify `unspecified_dims` via
       # `wrap_with_sharding_op`.
-      if config.use_shardy_partitioner.value:
-        sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim)
-      else:
-        sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
+      sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
       out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp))
   return out

@@ -2227,10 +2226,9 @@ def _nary_lower_hlo(op: Callable, ctx,

   out = op(*args)
   if config.sharding_in_types.value:
-    if config.use_shardy_partitioner.value:
-      out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim)
-    else:
-      out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    if aval_out.sharding.mesh.are_all_axes_collective:
+      return [out]
+    out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)]
   else:
     return [out]
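
Both hunks follow the same pattern; a condensed sketch (not the actual lowering code; mlir here is jax._src.interpreters.mlir, as imported in lax.py):

# Inside a fully-manual shard_map region every mesh axis is Collective,
# so the per-op sharding annotation is skipped; otherwise the aval's
# sharding is converted to an HLO sharding proto and attached as before.
def annotate_if_needed(ctx, op, aval_out):
  if aval_out.sharding.mesh.are_all_axes_collective:
    return op  # manual region: leave the op unannotated
  sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
  return mlir.wrap_with_sharding_op(ctx, op, aval_out, sp)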

jax/_src/mesh.py

Lines changed: 40 additions & 10 deletions
@@ -18,6 +18,7 @@
 import collections
 from collections.abc import Hashable, Sequence
 import contextlib
+import enum
 import functools
 import math
 import threading
@@ -101,6 +102,12 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)


+class AxisTypes(enum.Enum):
+  Auto = enum.auto()
+  User = enum.auto()
+  Collective = enum.auto()
+
+
 _mesh_object_dict = {}  # type: ignore


@@ -157,9 +164,11 @@ class Mesh(contextlib.ContextDecorator):

   devices: np.ndarray
   axis_names: tuple[MeshAxisName, ...]
+  axis_types: dict[AxisTypes, str | tuple[str, ...]] | None

   def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
-              axis_names: str | Sequence[MeshAxisName]):
+              axis_names: str | Sequence[MeshAxisName],
+              axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
     if not isinstance(devices, np.ndarray):
       devices = np.array(devices)
     if isinstance(axis_names, str):
@@ -175,7 +184,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
                        f"devices.ndim == {devices.ndim} and "
                        f"len(axis_names) == {len(axis_names)}.")

-    key = (axis_names, devices.shape, tuple(devices.flat))
+    # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
+    axis_types_tuple = (None if axis_types is None else
+                        tuple(axis_types.items()))
+    key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
     val = _mesh_object_dict.get(key, None)
     if val is not None:
       return val
@@ -184,11 +196,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
     self.devices = devices.copy()
     self.devices.flags.writeable = False
     self.axis_names = axis_names
+    self.axis_types = axis_types
+    self._axis_types_tuple = axis_types_tuple
     _mesh_object_dict[key] = self
     return self

   def __reduce__(self):
-    return (type(self), (self.devices, self.axis_names))
+    return (type(self), (self.devices, self.axis_names, self.axis_types))

   def __eq__(self, other):
     if not isinstance(other, Mesh):
@@ -199,12 +213,14 @@ def __eq__(self, other):
       return True
     return (self.axis_names == other.axis_names and
             self.devices.shape == other.devices.shape and
+            self._axis_types_tuple == other._axis_types_tuple and
             self._internal_device_list == other._internal_device_list)

   def __hash__(self):
     if not hasattr(self, '_hash'):
       self._hash = hash(
-          (self.axis_names, self._internal_device_list, self.devices.shape))
+          (self.axis_names, self._internal_device_list, self.devices.shape,
+           self._axis_types_tuple))
     return self._hash

   def __setattr__(self, name, value):
@@ -301,7 +317,8 @@ def __str__(self):
   def _repr(self):
     if self.empty:
       return "Mesh(device_ids=[], axis_names=())"
-    return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})"
+    atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
+    return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"

   def __repr__(self):
     return self._repr
@@ -313,7 +330,7 @@ def local_devices(self):

   @functools.cached_property
   def abstract_mesh(self):
-    return AbstractMesh(self.shape_tuple)
+    return AbstractMesh(self.shape_tuple, self.axis_types)


 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -338,25 +355,32 @@ class AbstractMesh:
   details.
   """

-  def __init__(self, shape_tuple: tuple[tuple[str, int], ...]):
+  def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
+               axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
     self.shape_tuple = shape_tuple
+    self.axis_types = axis_types
     if self.shape_tuple:
       self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
     else:
       self._axis_names, self._axis_sizes = (), ()
+    # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
+    self._axis_types_tuple = (None if axis_types is None else
+                              tuple(axis_types.items()))

   def __hash__(self):
-    return hash(self.shape_tuple)
+    return hash((self.shape_tuple, self._axis_types_tuple))

   def __eq__(self, other):
     if not isinstance(other, AbstractMesh):
       return False
     if id(self) == id(other):
       return True
-    return self.shape_tuple == other.shape_tuple
+    return (self.shape_tuple == other.shape_tuple and
+            self._axis_types_tuple == other._axis_types_tuple)

   def __repr__(self):
-    return f"AbstractMesh({self.shape_tuple})"
+    atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
+    return f"AbstractMesh({self.shape_tuple}{atr})"

   @property
   def axis_names(self):
@@ -382,6 +406,12 @@ def _internal_device_list(self):
   def empty(self):
     return self.size == 0

+  @functools.cached_property
+  def are_all_axes_collective(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return all(t == AxisTypes.Collective for t in self.axis_types.keys())
+
   @property
   def devices(self):
     _raise_value_error("devices")
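
A sketch of the new predicate, using the AbstractMesh constructor extended above (internal API; no devices required):

from jax._src.mesh import AbstractMesh, AxisTypes

# All axes tagged Collective -> the mesh reports itself as fully manual.
amesh = AbstractMesh((('x', 2), ('y', 2)),
                     axis_types={AxisTypes.Collective: ('x', 'y')})
assert amesh.are_all_axes_collective

# No axis_types (the default) is treated as "not collective".
assert not AbstractMesh((('x', 2), ('y', 2))).are_all_axes_collective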

jax/_src/partition_spec.py

Lines changed: 20 additions & 0 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 class _UnconstrainedPartitionSingleton:

   def __repr__(self):
@@ -48,3 +50,21 @@ def __repr__(self):

   def __reduce__(self):
     return (PartitionSpec, tuple(self))
+
+  def _normalized_spec(self, ndim: int) -> PartitionSpec:
+    out = []  # type: ignore
+    for p in self:
+      if p is None:
+        out.append(None)
+      elif p == self.UNCONSTRAINED:
+        out.append(p)
+      elif isinstance(p, (list, tuple)):
+        if len(p) == 1:
+          out.append(p[0])
+        else:
+          out.append(tuple(p))
+      else:
+        out.append(p)
+    if len(out) < ndim:
+      out.extend([None] * (ndim - len(out)))
+    return PartitionSpec(*out)
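
What the new helper does, by example (a private method; behavior read off the diff above):

from jax.sharding import PartitionSpec as P

# Single-element tuples unwrap, lists become tuples, and the spec is
# padded with None up to the requested rank.
assert P(('x',), None)._normalized_spec(3) == P('x', None, None)
assert P(['x', 'y'])._normalized_spec(2) == P(('x', 'y'), None)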

jax/_src/sharding_impls.py

Lines changed: 1 addition & 13 deletions
@@ -361,19 +361,7 @@ def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)

   def _normalized_spec(self, ndim: int) -> PartitionSpec:
-    out = []  # type: ignore
-    for p in self._parsed_pspec:
-      if p is None:
-        raise ValueError("UNCONSTRAINED is not supported yet.")
-      if not p:
-        out.append(None)
-      elif isinstance(p, tuple) and len(p) == 1:
-        out.append(p[0])
-      else:
-        out.append(p)
-    if len(out) < ndim:
-      out.extend([None] * (ndim - len(out)))
-    return PartitionSpec(*out)
+    return self.spec._normalized_spec(ndim)

   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
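
A quick consistency check of the delegation (sketch; assumes at least one available device and uses the private _normalized_spec helper):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
s = NamedSharding(mesh, P(('x',), None))
# The sharding-level helper now just forwards to its spec:
assert s._normalized_spec(3) == s.spec._normalized_spec(3) == P('x', None, None)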

jax/experimental/shard_map.py

Lines changed: 19 additions & 6 deletions
@@ -46,7 +46,7 @@
 from jax._src import traceback_util
 from jax._src import util
 from jax._src.core import Tracer
-from jax._src.mesh import AbstractMesh, Mesh
+from jax._src.mesh import AbstractMesh, Mesh, AxisTypes
 from jax._src.api import _shared_code_pmap, _prepare_pmap
 from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           windowed_reductions, convolution, fft, linalg,
@@ -528,17 +528,30 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
   raise NotImplementedError(f"Unsupported aval type: {type(aval)}")

 def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
-                       ) -> core.AbstractValue:
+                        ) -> core.AbstractValue:
   assert isinstance(aval, core.ShapedArray)
-  return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
-                           for i, sz in enumerate(aval.shape)))
+  new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
+                    for i, sz in enumerate(aval.shape))
+  if config.sharding_in_types.value:
+    new_mesh = AbstractMesh(
+        mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names})
+    new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
+  else:
+    new_sharding = None
+  return aval.update(shape=new_shape, sharding=new_sharding)
 core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array

 def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
                           aval: core.AbstractValue,) -> core.AbstractValue:
   assert isinstance(aval, core.ShapedArray)
-  return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
-                           for i, sz in enumerate(aval.shape)))
+  new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
+                    for i, sz in enumerate(aval.shape))
+  if config.sharding_in_types.value:
+    spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
+    new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
+  else:
+    new_sharding = None
+  return aval.update(shape=new_shape, sharding=new_sharding)
 core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array

 # Type-checking
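
The net effect on an operand's abstract value, sketched with internal APIs (illustrative; mirrors _shard_shaped_array above):

from jax._src.mesh import AbstractMesh, AxisTypes
from jax.sharding import NamedSharding, PartitionSpec as P

# Entering a full-manual shard_map over mesh axes ('x', 'y') of sizes
# (2, 2): an (8, 2) operand with spec P('x', 'y') becomes a per-device
# (4, 1) block whose sharding lives on a mesh where every axis is
# tagged Collective.
inner_mesh = AbstractMesh((('x', 2), ('y', 2)),
                          {AxisTypes.Collective: ('x', 'y')})
inner_sharding = NamedSharding(inner_mesh, P(None, None))
assert inner_sharding.mesh.are_all_axes_collective
# On exit, _unshard_shaped_array restores the global shape and rebuilds
# the spec from the out_specs names on an untagged AbstractMesh.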

tests/pjit_test.py

Lines changed: 24 additions & 0 deletions
@@ -5201,6 +5201,30 @@ def f(x):
     self.assertArraysEqual(out, np_inp)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

+  def test_shard_map_full_manual(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+
+    def g(x, y):
+      self.assertTrue(x.sharding.mesh.are_all_axes_collective)
+      self.assertTrue(y.sharding.mesh.are_all_axes_collective)
+      return x * y
+
+    @jax.jit
+    def f(x, y):
+      z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
+                    out_specs=P('x', 'y'))(x, y)
+      self.assertEqual(z.sharding.spec, P('x', 'y'))
+      out = z * 2
+      self.assertEqual(out.sharding.spec, P('x', 'y'))
+      return out
+
+    out = f(arr, arr2)
+    self.assertArraysEqual(out, (np_inp * np_inp) * 2)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
