from typing import Any, NamedTuple, Union, cast

from jax._src import core
+from jax._src import config
from jax._src import mesh as mesh_lib
-from jax._src import sharding
+from jax._src import sharding as jsharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
+from jax._src import source_info_util
from jax._src import xla_bridge
from jax._src import mesh_utils
from jax._src.lib import xla_client as xc
Index = tuple[slice, ...]
XLADeviceAssignment = tuple[Device, ...]
# TODO(yashkatariya): Remove this after 3 months of deprecation.
-XLACompatibleSharding = sharding.Sharding
+XLACompatibleSharding = jsharding.Sharding

@dataclasses.dataclass(frozen=True)
class TransferToMemoryKind:
@@ -219,7 +221,7 @@ def named_sharding_to_xla_hlo_sharding(


@use_cpp_class(xc.NamedSharding)
-class NamedSharding(sharding.Sharding):
+class NamedSharding(jsharding.Sharding):
  r"""A :class:`NamedSharding` expresses sharding using named axes.

  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@@ -388,9 +390,6 @@ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
      spec = PartitionSpec(*spec)
    return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

-  def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
-    return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)
-
  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

@@ -415,7 +414,7 @@ def get_replicated_hlo_sharding():


@use_cpp_class(xc.SingleDeviceSharding)
-class SingleDeviceSharding(sharding.Sharding):
+class SingleDeviceSharding(jsharding.Sharding):
  """A :class:`Sharding` that places its data on a single device.

  Args:
@@ -503,7 +502,7 @@ def pmap_sharding_devices_indices_map(


@use_cpp_class(xc.PmapSharding)
-class PmapSharding(sharding.Sharding):
+class PmapSharding(jsharding.Sharding):
  """Describes a sharding used by :func:`jax.pmap`."""
  devices: np.ndarray
  sharding_spec: sharding_specs.ShardingSpec
@@ -713,7 +712,7 @@ def _positional_sharding_to_xla_hlo_sharding(
  return xc.HloSharding.from_proto(pbuf)


-class PositionalSharding(sharding.Sharding):
+class PositionalSharding(jsharding.Sharding):
  _devices: tuple[xc.Device, ...]
  _memory_kind: str | None
  _ids: np.ndarray  # dtype DeviceIdSet
@@ -820,7 +819,7 @@ def with_memory_kind(self, kind: str) -> PositionalSharding:
  def is_fully_replicated(self) -> bool:
    return self.shape == (1,) * self.ndim

-  # sharding.Sharding interface
+  # jsharding.Sharding interface

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
@@ -868,7 +867,7 @@ def __eq__(self, other) -> bool:


@use_cpp_class(xc.GSPMDSharding)
-class GSPMDSharding(sharding.Sharding):
+class GSPMDSharding(jsharding.Sharding):
  _devices: tuple[Device, ...]
  _hlo_sharding: xc.HloSharding
  _memory_kind: str | None
@@ -1122,7 +1121,7 @@ def prepare_axis_resources(axis_resources, arg_name,
  for entry in entries:
    if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
      new_entries.append(entry)
-    elif isinstance(entry, sharding.Sharding):
+    elif isinstance(entry, jsharding.Sharding):
      if isinstance(entry, PmapSharding):
        raise ValueError(f'One of {what} got sharding {entry} which is not '
                         'allowed.')
@@ -1138,7 +1137,7 @@ def prepare_axis_resources(axis_resources, arg_name,
def _check_unique_resources(axis_resources, arg_name):
  for arg_axis_resources in axis_resources:
    if not arg_axis_resources: continue
-    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
+    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
      continue
    constrained_dims = [d for d in arg_axis_resources if d is not None]
    resource_counts = collections.Counter(
@@ -1371,7 +1370,7 @@ class NonUniformShardingError(ValueError):


def get_process_index_and_count(
-    tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
+    tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
  """Get current process index and number of unique processes for given dimension.

  This function facilitates mapping of process-level data to individual
@@ -1486,7 +1485,7 @@ def get_process_index_and_count(


def local_to_global_shape(
-    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
+    sharding: jsharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
  """Computes the global shape given the per-process shape, if possible.

  The returned shape will have the size of the global tensor in that dimension
@@ -1545,7 +1544,7 @@ def local_to_global_shape(


def num_addressable_indices(
-    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
+    tensor_sharding: jsharding.Sharding, dim: int, global_shape: Shape) -> int:
  """Returns the number of indices for a given dimension this host has access to.

  Each host can have multiple devices that are spanning
@@ -1579,7 +1578,7 @@ def num_addressable_indices(
  """
  # TODO(sandler, yashkatariya): Consider making this function public.
  addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
-  addressables = cast(Mapping[sharding.Device, Index], addressables)
+  addressables = cast(Mapping[jsharding.Device, Index], addressables)
  num_unique_slices = len({
      _slice_as_tuple(addressable[dim]) for addressable in addressables.values()
  })
@@ -1596,7 +1595,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
  new_op_sharding.tile_assignment_dimensions = tad
  return xc.HloSharding.from_proto(new_op_sharding)

-def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
+def is_single_device_sharding(sharding: jsharding.Sharding) -> bool:
  # Special case PmapSharding here because PmapSharding maps away an axis
  # and needs to be handled separately.
  return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)
@@ -1625,7 +1624,7 @@ def make_key_array_phys_sharding(aval, sharding):


def physical_sharding(
-    aval, sharding: sharding.Sharding) -> sharding.Sharding:
+    aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
  return make_key_array_phys_sharding(aval, sharding)


@@ -1642,7 +1641,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
  return GSPMDSharding(phys_sharding._device_assignment,
                       xc.HloSharding.from_proto(logical_op_sharding))

-def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
+def check_replicated_trailing_dims(sharding: jsharding.Sharding, aval):
  if isinstance(sharding, PmapSharding):
    return
  phys_aval = core.physical_aval(aval)
@@ -1655,7 +1654,7 @@ def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
        f" sharding: {sharding}, partitions: {partitions}, "
        f"num_trailing_dims: {num_trailing_dims}")

-def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
+def logical_sharding(aval, phys_sharding) -> jsharding.Sharding:
  # The trailing dims should always be replicated.
  check_replicated_trailing_dims(phys_sharding, aval)

@@ -1695,6 +1694,44 @@ def _gspmd_to_named_sharding_via_mesh(
      mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
      out_s.memory_kind)

+def flatten_spec(spec):
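+  # Flatten the PartitionSpec entries: skip None (unsharded) dimensions and
+  # expand tuples of axis names so each mesh axis name appears individually.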
+  out = []
+  for s in spec:
+    if s is None:
+      continue
+    if isinstance(s, tuple):
+      out.extend(s)
+    else:
+      out.append(s)
+  return out
+
+def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
+                          check_mesh_consistency: bool = True
+                          ) -> NamedSharding | None:
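+  # Normalize the given sharding to a NamedSharding over the current abstract
+  # mesh when the sharding-in-types config is enabled; otherwise pass through.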
+  if not config.sharding_in_types.value:
+    return sharding  # type: ignore
+  if sharding is None:
+    return sharding
+
+  if isinstance(sharding, PartitionSpec):
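+    # A bare PartitionSpec is resolved against the current abstract mesh.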
+    sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding)  # type: ignore
+  else:
+    if (check_mesh_consistency and
+        sharding.mesh != mesh_lib.get_abstract_mesh()):
+      raise ValueError(
+          f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
+          f' of sharding {sharding.mesh}. This error occurs at source: '
+          f' {source_info_util.summarize(source_info_util.current())}')
+
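+  # Axis names whose mesh axis type is Auto or Collective may not appear in
+  # an explicit PartitionSpec, so reject them here.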
+  for s in flatten_spec(sharding.spec):
+    if sharding.mesh._name_to_type[s] in {
+        mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
+      raise ValueError(
+          'PartitionSpec cannot contain axis names that are of type Auto or'
+          f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
+          f' {s} or type: {sharding.mesh._name_to_type[s]}')
+  return sharding
+

def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
              *, devices: Sequence[xc.Device] | None = None,