Skip to content

Commit cf308a8

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Use an PartitionSpec.UNCONSTRAINED to represent unconstrained dimensions in ParsedPartitionSpec, rather than None.
This makes PartitionSpec and ParsedPartitionSpec more similar, and fixes some TODOs. PiperOrigin-RevId: 724927217
1 parent 8401d9b commit cf308a8

File tree

4 files changed

+25
-18
lines changed

4 files changed

+25
-18
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from jax._src.interpreters import partial_eval as pe
5050
from jax._src.interpreters import xla
5151
from jax._src.layout import AutoLayout, DeviceLocalLayout
52+
from jax._src.partition_spec import PartitionSpec
5253
from jax._src.sharding import Sharding as JSharding
5354
from jax._src.sharding_impls import (AUTO, NamedSharding,
5455
modify_sdy_sharding_wrt_axis_types,
@@ -1062,20 +1063,27 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
10621063
assert isinstance(s, JSharding)
10631064
return s.memory_kind
10641065

1066+
10651067
def contains_unconstrained(s):
1066-
return isinstance(s, NamedSharding) and None in s._parsed_pspec
1068+
return (
1069+
isinstance(s, NamedSharding)
1070+
and PartitionSpec.UNCONSTRAINED in s._parsed_pspec
1071+
)
1072+
10671073

10681074
def all_unconstrained(s, aval):
10691075
if isinstance(s, NamedSharding):
10701076
if aval.ndim != len(s._parsed_pspec):
10711077
return False
1072-
return all(p is None for p in s._parsed_pspec)
1078+
return all(p is PartitionSpec.UNCONSTRAINED for p in s._parsed_pspec)
10731079
return False
10741080

10751081
def _get_unconstrained_dimensions(s, aval):
10761082
us = contains_unconstrained(s)
1077-
return (us, all_unconstrained(s, aval),
1078-
({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
1083+
return (
1084+
us, all_unconstrained(s, aval),
1085+
({i for i, p in enumerate(s._parsed_pspec)
1086+
if p is PartitionSpec.UNCONSTRAINED} if us else None))
10791087

10801088
def lower_jaxpr_to_module(
10811089
module_name: str,

jax/_src/partition_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec:
5656
for p in self:
5757
if p is None:
5858
out.append(None)
59-
elif isinstance(p, UnconstrainedSingleton):
59+
elif p is _UNCONSTRAINED_PARTITION:
6060
out.append(None)
6161
elif isinstance(p, (list, tuple)):
6262
if len(p) == 1:

jax/_src/pjit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2025,7 +2025,8 @@ def _pjit_batcher_for_sharding(
20252025
if sharding_impls.is_op_sharding_replicated(hlo_s):
20262026
return s
20272027
if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
2028-
parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, None)
2028+
parsed_pspec = s._parsed_pspec.insert_axis_partitions(
2029+
dim, PartitionSpec.UNCONSTRAINED)
20292030
return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
20302031
new_op = hlo_s.to_proto().clone()
20312032
tad = list(new_op.tile_assignment_dimensions)
@@ -2659,7 +2660,6 @@ def _sharding_constraint_batcher(
26592660
f"{sharding.spec}")
26602661
x, = vals_in
26612662
d, = dims_in
2662-
# None means unconstrained in ParsedPartitionSpec
26632663
unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims}
26642664
if axis_data.spmd_name is None:
26652665
unconstrained_dims.add(d)
@@ -2887,7 +2887,7 @@ def use_explicit_axes(*axes):
28872887
def get_unconstrained_dims(sharding: NamedSharding):
28882888
assert sharding._parsed_pspec is not None
28892889
return {i for i, axes in enumerate(sharding._parsed_pspec)
2890-
if axes is None}
2890+
if axes is PartitionSpec.UNCONSTRAINED}
28912891

28922892

28932893
def _get_partition_spec(

jax/_src/sharding_impls.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class TransferToMemoryKind:
5656
@util.cache(max_size=128, trace_context_in_key=False)
5757
def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
5858
for p in parsed_pspec:
59-
if p is not None:
59+
if p is not PartitionSpec.UNCONSTRAINED:
6060
for r in p:
6161
if r not in mesh.shape:
6262
raise ValueError(
@@ -71,7 +71,7 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
7171
@util.cache(max_size=128, trace_context_in_key=False)
7272
def _check_axis_type_consistency(mesh, parsed_pspec):
7373
for p in parsed_pspec:
74-
if p is not None:
74+
if p is not PartitionSpec.UNCONSTRAINED:
7575
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
7676
raise ValueError(
7777
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
@@ -431,7 +431,7 @@ def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
431431
dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
432432
for _ in range(num_dimensions)]
433433
for i, dim_spec in enumerate(self._parsed_pspec):
434-
if dim_spec is None:
434+
if dim_spec is PartitionSpec.UNCONSTRAINED:
435435
dim_shardings[i].is_closed = False
436436
elif not dim_spec:
437437
# Already empty and closed sharding.
@@ -1079,7 +1079,7 @@ def get_array_mapping(
10791079
return axis_resources
10801080
return OrderedDict((axis, i)
10811081
for i, axes in enumerate(axis_resources)
1082-
if axes is not None for axis in axes)
1082+
if axes is not PartitionSpec.UNCONSTRAINED for axis in axes)
10831083

10841084

10851085
get_single_pspec = lambda p: array_mapping_to_axis_resources(
@@ -1090,12 +1090,11 @@ class ParsedPartitionSpec:
10901090
__slots__ = ('_user_spec', 'partitions')
10911091

10921092
_user_spec: PartitionSpec | None
1093-
partitions: tuple[tuple[MeshAxisName, ...] | None, ...]
1093+
partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...]
10941094

10951095
def __init__(self, user_spec, partitions):
10961096
self._user_spec = user_spec
1097-
# None in partitions represents unconstrained dim.
1098-
# TODO(yashkatariya): May use a sentinel value.
1097+
assert None not in partitions, partitions
10991098
self.partitions = tuple(partitions)
11001099

11011100
def get_partition_spec(self) -> PartitionSpec:
@@ -1130,10 +1129,10 @@ def from_user_input(
11301129
axis_spec = ()
11311130
elif isinstance(axis_spec, (list, tuple)):
11321131
axis_spec = tuple(axis_spec)
1133-
elif isinstance(axis_spec, UnconstrainedSingleton):
1132+
elif axis_spec is PartitionSpec.UNCONSTRAINED:
11341133
if not allow_unconstrained_dims:
11351134
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
1136-
axis_spec = None
1135+
axis_spec = PartitionSpec.UNCONSTRAINED
11371136
else:
11381137
axis_spec = (axis_spec,)
11391138
axis_specs.append(axis_spec)
@@ -1204,7 +1203,7 @@ def _check_unique_resources(
12041203
resource_counts: dict[MeshAxisName, int] = {}
12051204
duplicate = False
12061205
for d in arg_axis_resources:
1207-
if d is not None:
1206+
if d is not PartitionSpec.UNCONSTRAINED:
12081207
for resource in d:
12091208
count = resource_counts.get(resource, 0)
12101209
if count > 0:

0 commit comments

Comments
 (0)