
Commit 3848f0d

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also accept a PartitionSpec, not just a NamedSharding, as input.

If a PartitionSpec is passed, the mesh is read from the context. The primitives themselves still take only `NamedSharding`; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`. For the above functions, we also raise an error if the `PartitionSpec` contains mesh axis names whose type is Auto or Collective.

PiperOrigin-RevId: 713352542
1 parent c1a60c6 commit 3848f0d

File tree

8 files changed: +120 −61 lines
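A minimal sketch of the pattern this commit describes: API-level functions accept either a NamedSharding or a bare PartitionSpec, and a PartitionSpec is resolved against the abstract mesh read from the current context before the primitive's `.bind` is called. The helper name `_resolve_out_sharding` is illustrative only; the real helper is `canonicalize_sharding` in jax/_src/sharding_impls.py below.

# Illustrative only: mirrors the PartitionSpec -> NamedSharding promotion this commit
# performs above `.bind`; the real implementation is canonicalize_sharding (see below).
from jax._src import mesh as mesh_lib
from jax.sharding import NamedSharding, PartitionSpec

def _resolve_out_sharding(out_sharding):
  if isinstance(out_sharding, PartitionSpec):
    # The mesh is read from the surrounding context rather than passed explicitly.
    return NamedSharding(mesh_lib.get_abstract_mesh(), out_sharding)
  return out_sharding  # already a NamedSharding, or None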

jax/BUILD

Lines changed: 1 addition & 0 deletions
@@ -866,6 +866,7 @@ pytype_strict_library(
         ":partition_spec",
         ":sharding",
         ":sharding_specs",
+        ":source_info_util",
         ":tree_util",
         ":util",
         ":xla_bridge",

jax/_src/core.py

Lines changed: 3 additions & 0 deletions
@@ -1696,6 +1696,9 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None):
     self.weak_type = weak_type
     if config.sharding_in_types.value:
       self.sharding = get_sharding(sharding, len(self.shape))
+      if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
+        raise ValueError(
+            f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")

   def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
     if shape is None:
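A hedged note on the hunk above: with sharding_in_types enabled, the sharding stored on an aval must be backed by an AbstractMesh (a shape-and-names-only mesh), never a concrete device Mesh. The helper below merely restates that check for illustration; it is not part of the diff.

# Illustrative restatement of the invariant added to ShapedArray.__init__ above.
from jax._src import mesh as mesh_lib

def _check_aval_mesh(sharding):
  if not isinstance(sharding.mesh, mesh_lib.AbstractMesh):
    raise ValueError(
        f"Mesh of an aval must be an AbstractMesh. Got {sharding.mesh}")
  return sharding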

jax/_src/lax/lax.py

Lines changed: 10 additions & 6 deletions
@@ -64,7 +64,7 @@
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.sharding_impls import (PmapSharding, NamedSharding,
-                                     PartitionSpec as P)
+                                     PartitionSpec as P, canonicalize_sharding)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
 from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
                            safe_map, safe_zip, split_list, weakref_lru_cache)
@@ -586,6 +586,8 @@ def _convert_element_type(
       isinstance(operand, Array)):
     sharding = operand.sharding

+  sharding = canonicalize_sharding(sharding, check_mesh_consistency=False)  # type: ignore
+
   if (warn_on_complex_to_real_cast and
       dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
@@ -1431,6 +1433,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   if not config.sharding_in_types.value and sharding is not None:
     raise NotImplementedError("sharding argument to broadcast_in_dim is only "
                               "allowed when sharding_in_types config is on.")
+  sharding = canonicalize_sharding(sharding)
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
       isinstance(operand, Array) and sharding is None):
     return operand
@@ -1505,7 +1508,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
     return operand
   else:
     dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
-
+    sharding = canonicalize_sharding(sharding)
     return reshape_p.bind(
       operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
       dimensions=None if dims is None or same_dims else dims,
@@ -1947,19 +1950,20 @@ def iota(dtype: DTypeLike, size: int) -> Array:
   return broadcasted_iota(dtype, (size,), 0)

 def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
-                     _sharding=None) -> Array:
+                     sharding=None) -> Array:
   """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = canonicalize_shape(shape)
   dynamic_shape = [d for d in shape if isinstance(d, core.Tracer)]
   static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   dimension = core.concrete_or_error(
       int, dimension, "dimension argument of lax.broadcasted_iota")
-  if not config.sharding_in_types.value and _sharding is not None:
+  if not config.sharding_in_types.value and sharding is not None:
     raise NotImplementedError('sharding support for broadcasted_iota is not '
                               'implemented outside of sharding_in_types mode.')
+  sharding = canonicalize_sharding(sharding)
   return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
-                     dimension=dimension, sharding=_sharding)
+                     dimension=dimension, sharding=sharding)

 def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
   """Like numpy.eye, create a 2D array with ones on a diagonal."""
@@ -5560,7 +5564,7 @@ def _compute_argminmax(value_comparator, get_identity,
   axis, = axes
   indices = broadcasted_iota(
     index_dtype, np.shape(operand), axis,
-    _sharding=operand.sharding if config.sharding_in_types.value else None)
+    sharding=operand.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),
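A hedged usage sketch of the renamed keyword: `lax.broadcasted_iota` now exposes `sharding=` (previously `_sharding=`), and with the sharding_in_types config on it may be given a bare PartitionSpec, which `canonicalize_sharding` resolves against the context's abstract mesh. The mesh and axis-name setup is assumed rather than shown in this diff; with the config off, passing a sharding still raises NotImplementedError.

# Assumes sharding_in_types is enabled and the current context's abstract mesh has an
# explicitly-typed 'x' axis; otherwise this raises, as the code above shows.
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as P

def sharded_iota():
  # Row index array of shape (8, 4), sharded along 'x' on the leading dimension.
  return lax.broadcasted_iota(jnp.int32, (8, 4), 0, sharding=P('x', None))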

jax/_src/nn/functions.py

Lines changed: 1 addition & 1 deletion
@@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *,
   else:
     rhs_sharding = None
   rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
-                             _sharding=rhs_sharding)
+                             sharding=rhs_sharding)
   return (lhs == rhs).astype(dtype)

 # TODO(slebedev): Change the type of `x` to `ArrayLike`.

jax/_src/numpy/lax_numpy.py

Lines changed: 5 additions & 3 deletions
@@ -54,7 +54,7 @@
 from jax._src.core import ShapedArray
 from jax._src.custom_derivatives import custom_jvp
 from jax._src.lax import lax as lax_internal
-from jax._src.lax.lax import ( PrecisionLike,_array_copy,
+from jax._src.lax.lax import (PrecisionLike,_array_copy,
                               _sort_le_comparator, _sort_lt_comparator)
 from jax._src.lib import xla_client as xc
 from jax._src.numpy import reductions
@@ -69,8 +69,9 @@
     NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
     ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
     tuple_replace)
-from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
-                          PartitionSpec as P)
+from jax.sharding import Sharding
+from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding,
+                                     PartitionSpec as P, canonicalize_sharding)
 from jax.tree_util import tree_flatten, tree_leaves, tree_map
 import numpy as np
 import opt_einsum
@@ -9873,6 +9874,7 @@ def _einsum(
   if out_type is not None and not config.sharding_in_types.value:
     raise NotImplementedError("out_type only works when sharding_in_types "
                               "config is True.")
+  out_type = canonicalize_sharding(out_type)
   if out_type is not None and not isinstance(out_type, NamedSharding):
     raise NotImplementedError(
         "`out_type` argument of `einsum` only supports NamedSharding instances."

jax/_src/pjit.py

Lines changed: 13 additions & 5 deletions
@@ -67,7 +67,8 @@
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
-    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
+    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources,
+    parse_flatten_op_sharding, canonicalize_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
 from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
 from jax._src.traceback_util import api_boundary
@@ -2670,13 +2671,20 @@ def _sharding_constraint_batcher(

 def sharding_cast(xs, shardings):
   if isinstance(shardings, NamedSharding):
-    return tree_map(lambda x: sharding_cast_p.bind(
-        x, src_sharding=x.sharding, dst_sharding=shardings), xs)
+    return tree_map(
+        lambda x: sharding_cast_p.bind(
+            x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
+                shardings, check_mesh_consistency=False)),
+        xs)

   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings)
-  out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s)
-              for x, s in safe_zip(x_flat, shardings_flat)]
+  out_flat = [
+      sharding_cast_p.bind(
+          x, src_sharding=x.sharding,
+          dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
+      for x, s in safe_zip(x_flat, shardings_flat)
+  ]
   return tree_unflatten(treedef, out_flat)

 sharding_cast_p = core.Primitive('sharding_cast')
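A hedged sketch of the `sharding_cast` change: every destination sharding (whether a single NamedSharding or per-leaf entries) is now passed through `canonicalize_sharding` with `check_mesh_consistency=False`, so PartitionSpec entries are accepted as well. `sharding_cast` is an internal helper in `jax._src.pjit`; its public exposure path is not part of this diff.

# Assumes sharding_in_types is enabled and the current context's abstract mesh has an
# 'x' axis; sharding_cast is internal API in this version of the code.
from jax._src.pjit import sharding_cast
from jax.sharding import PartitionSpec as P

def to_x_sharded(arr):
  # A bare PartitionSpec is resolved against the context mesh before binding.
  return sharding_cast(arr, P('x'))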

jax/_src/sharding_impls.py

Lines changed: 58 additions & 21 deletions
@@ -24,11 +24,13 @@
 from typing import Any, NamedTuple, Union, cast

 from jax._src import core
+from jax._src import config
 from jax._src import mesh as mesh_lib
-from jax._src import sharding
+from jax._src import sharding as jsharding
 from jax._src import sharding_specs
 from jax._src import tree_util
 from jax._src import util
+from jax._src import source_info_util
 from jax._src import xla_bridge
 from jax._src import mesh_utils
 from jax._src.lib import xla_client as xc
@@ -45,7 +47,7 @@
 Index = tuple[slice, ...]
 XLADeviceAssignment = tuple[Device, ...]
 # TODO(yashkatariya): Remove this after 3 months of deprecation.
-XLACompatibleSharding = sharding.Sharding
+XLACompatibleSharding = jsharding.Sharding

 @dataclasses.dataclass(frozen=True)
 class TransferToMemoryKind:
@@ -219,7 +221,7 @@ def named_sharding_to_xla_hlo_sharding(


 @use_cpp_class(xc.NamedSharding)
-class NamedSharding(sharding.Sharding):
+class NamedSharding(jsharding.Sharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.

   A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@@ -388,9 +390,6 @@ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
       spec = PartitionSpec(*spec)
     return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

-  def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
-    return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)
-
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

@@ -415,7 +414,7 @@ def get_replicated_hlo_sharding():


 @use_cpp_class(xc.SingleDeviceSharding)
-class SingleDeviceSharding(sharding.Sharding):
+class SingleDeviceSharding(jsharding.Sharding):
   """A :class:`Sharding` that places its data on a single device.

   Args:
@@ -503,7 +502,7 @@ def pmap_sharding_devices_indices_map(


 @use_cpp_class(xc.PmapSharding)
-class PmapSharding(sharding.Sharding):
+class PmapSharding(jsharding.Sharding):
   """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
@@ -713,7 +712,7 @@ def _positional_sharding_to_xla_hlo_sharding(
   return xc.HloSharding.from_proto(pbuf)


-class PositionalSharding(sharding.Sharding):
+class PositionalSharding(jsharding.Sharding):
   _devices: tuple[xc.Device, ...]
   _memory_kind: str | None
   _ids: np.ndarray  # dtype DeviceIdSet
@@ -820,7 +819,7 @@ def with_memory_kind(self, kind: str) -> PositionalSharding:
   def is_fully_replicated(self) -> bool:
     return self.shape == (1,) * self.ndim

-  # sharding.Sharding interface
+  # jsharding.Sharding interface

   @property
   def _device_assignment(self) -> XLADeviceAssignment:
@@ -868,7 +867,7 @@ def __eq__(self, other) -> bool:


 @use_cpp_class(xc.GSPMDSharding)
-class GSPMDSharding(sharding.Sharding):
+class GSPMDSharding(jsharding.Sharding):
   _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
@@ -1122,7 +1121,7 @@ def prepare_axis_resources(axis_resources, arg_name,
   for entry in entries:
     if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
       new_entries.append(entry)
-    elif isinstance(entry, sharding.Sharding):
+    elif isinstance(entry, jsharding.Sharding):
       if isinstance(entry, PmapSharding):
         raise ValueError(f'One of {what} got sharding {entry} which is not '
                          'allowed.')
@@ -1138,7 +1137,7 @@ def prepare_axis_resources(axis_resources, arg_name,
 def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
-    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
+    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
       continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = collections.Counter(
@@ -1371,7 +1370,7 @@ class NonUniformShardingError(ValueError):


 def get_process_index_and_count(
-    tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
+    tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
   """Get current process index and number of unique processes for given dimension.

   This function facilitates mapping of process-level data to individual
@@ -1486,7 +1485,7 @@ def get_process_index_and_count(


 def local_to_global_shape(
-    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
+    sharding: jsharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
   """Computes the global shape given the per process if possible.

   The returned shape will have the size of the global tensor in that dimension
@@ -1545,7 +1544,7 @@ def local_to_global_shape(


 def num_addressable_indices(
-    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
+    tensor_sharding: jsharding.Sharding, dim: int, global_shape: Shape) -> int:
   """Returns the number of indices for given dimension this host has access to.

   Each host can have multiple number of devices that are spanning
@@ -1579,7 +1578,7 @@
   """
   # TODO(sandler, yashkatariya): Consider making this function public.
   addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
-  addressables = cast(Mapping[sharding.Device, Index], addressables)
+  addressables = cast(Mapping[jsharding.Device, Index], addressables)
   num_unique_slices = len({
       _slice_as_tuple(addressable[dim]) for addressable in addressables.values()
   })
@@ -1596,7 +1595,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
   new_op_sharding.tile_assignment_dimensions = tad
   return xc.HloSharding.from_proto(new_op_sharding)

-def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
+def is_single_device_sharding(sharding: jsharding.Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.test_pjit_single_device_sharding_add
   return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)
@@ -1625,7 +1624,7 @@ def make_key_array_phys_sharding(aval, sharding):


 def physical_sharding(
-    aval, sharding: sharding.Sharding) -> sharding.Sharding:
+    aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
   return make_key_array_phys_sharding(aval, sharding)


@@ -1642,7 +1641,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
   return GSPMDSharding(phys_sharding._device_assignment,
                        xc.HloSharding.from_proto(logical_op_sharding))

-def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
+def check_replicated_trailing_dims(sharding: jsharding.Sharding, aval):
   if isinstance(sharding, PmapSharding):
     return
   phys_aval = core.physical_aval(aval)
@@ -1655,7 +1654,7 @@ def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
         f" sharding: {sharding}, partitions: {partitions}, "
         f"num_trailing_dims: {num_trailing_dims}")

-def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
+def logical_sharding(aval, phys_sharding) -> jsharding.Sharding:
   # The trailing dims should always be replicated.
   check_replicated_trailing_dims(phys_sharding, aval)

@@ -1695,6 +1694,44 @@ def _gspmd_to_named_sharding_via_mesh(
       mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
       out_s.memory_kind)

+def flatten_spec(spec):
+  out = []
+  for s in spec:
+    if s is None:
+      continue
+    if isinstance(s, tuple):
+      out.extend(s)
+    else:
+      out.append(s)
+  return out
+
+def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
+                          check_mesh_consistency: bool = True
+                          ) -> NamedSharding | None:
+  if not config.sharding_in_types.value:
+    return sharding  # type: ignore
+  if sharding is None:
+    return sharding
+
+  if isinstance(sharding, PartitionSpec):
+    sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding)  # type: ignore
+  else:
+    if (check_mesh_consistency and
+        sharding.mesh != mesh_lib.get_abstract_mesh()):
+      raise ValueError(
+          f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
+          f' of sharding {sharding.mesh}. This error occurs at source: '
+          f' {source_info_util.summarize(source_info_util.current())}')
+
+  for s in flatten_spec(sharding.spec):
+    if sharding.mesh._name_to_type[s] in {
+        mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
+      raise ValueError(
+          'PartitionSpec cannot contain axis names that are of type Auto or'
+          f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
+          f' {s} or type: {sharding.mesh._name_to_type[s]}')
+  return sharding
+

 def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
               *, devices: Sequence[xc.Device] | None = None,
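A brief illustration of the two helpers added above: `flatten_spec` drops `None` entries and unpacks tuple entries of a PartitionSpec, and `canonicalize_sharding` uses it to verify that no spec axis is of type Auto or Collective after promoting a PartitionSpec to a NamedSharding over the context's abstract mesh.

from jax.sharding import PartitionSpec as P
from jax._src.sharding_impls import flatten_spec

# Tuple entries are unpacked and None (unsharded) dimensions are skipped.
assert flatten_spec(P(('x', 'y'), None, 'z')) == ['x', 'y', 'z']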
