Skip to content

Commit 4489303

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Delete ParsedPartitionSpec and preprocess function and do a couple more cleanups
PiperOrigin-RevId: 738503430
1 parent 04454b6 commit 4489303

File tree

4 files changed

+41
-146
lines changed

4 files changed

+41
-146
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,14 +2463,41 @@ def cost_analysis(self) -> dict[str, float]:
24632463
return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
24642464

24652465

2466+
def get_op_sharding_from_executable(
2467+
executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
2468+
in_op_shardings: list[xc.OpSharding] = []
2469+
parameter_shardings_from_xla = executable.get_parameter_shardings()
2470+
if parameter_shardings_from_xla is not None:
2471+
in_op_shardings = parameter_shardings_from_xla
2472+
2473+
out_op_shardings: list[xc.OpSharding] = []
2474+
output_shardings_from_xla = executable.get_output_shardings()
2475+
if output_shardings_from_xla is not None:
2476+
out_op_shardings = output_shardings_from_xla
2477+
2478+
return in_op_shardings, out_op_shardings
2479+
2480+
2481+
def get_pspec_from_executable(
2482+
executable, mesh: Mesh
2483+
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
2484+
input_op_s, output_op_s = get_op_sharding_from_executable(executable)
2485+
in_pspec: list[PartitionSpec] = []
2486+
for s in input_op_s:
2487+
in_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh))
2488+
2489+
out_pspec: list[PartitionSpec] = []
2490+
for s in output_op_s:
2491+
out_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh))
2492+
return tuple(in_pspec), tuple(out_pspec)
2493+
2494+
24662495
def get_out_shardings_from_executable(
24672496
xla_executable,
24682497
device_assignment: Sequence[xc.Device],
24692498
num_out_avals: int,
24702499
num_ordered_effects: int,
24712500
) -> Sequence[sharding_impls.GSPMDSharding] | None:
2472-
from jax._src import pjit
2473-
24742501
try:
24752502
omk = xla_executable.get_output_memory_kinds()[0]
24762503
if num_ordered_effects > 0:
@@ -2486,7 +2513,7 @@ def get_out_shardings_from_executable(
24862513
return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk)
24872514
for mk in omk]
24882515

2489-
_, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
2516+
_, out_op_shardings = get_op_sharding_from_executable(xla_executable)
24902517
if not out_op_shardings:
24912518
return None
24922519

@@ -2517,14 +2544,12 @@ def _get_in_shardings_from_xla(
25172544
num_ordered_effects: int
25182545
) -> Sequence[GSPMDSharding] | None:
25192546
"""Returns input shardings from XLA."""
2520-
from jax._src import pjit
2521-
25222547
# When the device assignment only has 1 device, SPMD partitioner will not run.
25232548
# Hence the op shardings will not be set on the `hlo_module`.
25242549
if len(device_assignment) == 1:
25252550
return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals
25262551

2527-
in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
2552+
in_op_shardings, _ = get_op_sharding_from_executable(xla_executable)
25282553
if not in_op_shardings:
25292554
return None
25302555

@@ -2543,9 +2568,7 @@ def _get_in_shardings_from_xla(
25432568
def _get_mesh_pspec_shardings_from_executable(
25442569
xla_executable, mesh: Mesh
25452570
) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]:
2546-
from jax._src import pjit
2547-
2548-
in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh)
2571+
in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh)
25492572
return ([NamedSharding(mesh, i) for i in in_pspec],
25502573
[NamedSharding(mesh, o) for o in out_pspec])
25512574

jax/_src/named_sharding.py

Lines changed: 7 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
from typing import Any, Union
2222

2323
from jax._src import config
24-
from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert
24+
from jax._src.util import use_cpp_class, cache, use_cpp_method
2525
from jax._src.lib import xla_client as xc
2626
from jax._src.lib.mlir.dialects import sdy
2727
from jax._src import mesh as mesh_lib
28-
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
28+
from jax._src.partition_spec import PartitionSpec
2929
from jax._src import sharding as JSharding
3030
from jax._src import xla_bridge as xb
3131
import numpy as np
@@ -198,7 +198,7 @@ def is_fully_addressable(self) -> bool:
198198
# Speed up `is_fully_addressable` since there is a high chance that the
199199
# mesh across multiple NamedSharding objects will be the same.
200200
if config.enable_empty_arrays.value:
201-
client = self._internal_device_list[0].client
201+
client = self._internal_device_list[0].client # type: ignore
202202
return (len(self.mesh._process_indices) == 1 and
203203
next(iter(self.mesh._process_indices)) ==
204204
xb.process_index(client))
@@ -325,80 +325,6 @@ def __repr__(self):
325325
if self.replicated_axes else '')
326326
return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"
327327

328-
# TODO(yashkatariya): Remove this after jax 0.5.2 release
329-
class ParsedPartitionSpec:
330-
__slots__ = ('_user_spec', 'partitions')
331-
332-
_user_spec: PartitionSpec | None
333-
partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...]
334-
335-
def __init__(self, user_spec, partitions):
336-
self._user_spec = user_spec
337-
assert None not in partitions, partitions
338-
self.partitions = tuple(partitions)
339-
340-
def get_partition_spec(self) -> PartitionSpec:
341-
if isinstance(self._user_spec, PartitionSpec):
342-
return self._user_spec
343-
else:
344-
return get_single_pspec(self)
345-
346-
def insert_axis_partitions(self, dim, val):
347-
parts = self.partitions
348-
too_short = dim - len(parts)
349-
if too_short > 0:
350-
parts += ((),) * too_short
351-
new_partitions = tuple_insert(parts, dim, val)
352-
return ParsedPartitionSpec(None, new_partitions)
353-
354-
@classmethod
355-
def from_user_input(
356-
cls,
357-
entry: PartitionSpec | None,
358-
arg_name: str,
359-
allow_unconstrained_dims: bool = False,
360-
) -> ParsedPartitionSpec:
361-
if entry is None:
362-
return cls(entry, ())
363-
if not isinstance(entry, PartitionSpec):
364-
raise TypeError(f"{arg_name} are expected to be "
365-
f"PartitionSpec instances or None, but got {entry}")
366-
axis_specs = []
367-
for axis_spec in entry:
368-
if axis_spec is None:
369-
axis_spec = ()
370-
elif isinstance(axis_spec, (list, tuple)):
371-
axis_spec = tuple(axis_spec)
372-
elif axis_spec is PartitionSpec.UNCONSTRAINED:
373-
if not allow_unconstrained_dims:
374-
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
375-
axis_spec = PartitionSpec.UNCONSTRAINED
376-
else:
377-
axis_spec = (axis_spec,)
378-
axis_specs.append(axis_spec)
379-
new_entry = PartitionSpec(
380-
*[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry])
381-
return cls(new_entry, axis_specs)
382-
383-
def __hash__(self):
384-
return hash(self.partitions)
385-
386-
def __eq__(self, other):
387-
if not isinstance(other, ParsedPartitionSpec):
388-
return False
389-
return self.partitions == other.partitions
390-
391-
def __len__(self):
392-
return len(self.partitions)
393-
394-
def __getitem__(self, i):
395-
return self.partitions[i]
396-
397-
def __iter__(self):
398-
return iter(self.partitions)
399-
400-
def __repr__(self):
401-
return f"ParsedPartitionSpec(partitions={self.partitions})"
402328

403329
@cache(max_size=4096, trace_context_in_key=False)
404330
def named_sharding_to_xla_hlo_sharding(
@@ -491,18 +417,8 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
491417
partitions.append(None)
492418
return PartitionSpec(*partitions)
493419

494-
get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore
495-
496-
# TODO(yashkatariya): Remove this after jax 0.5.2 release
497-
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
498-
if parsed_pspec is None:
499-
spec = PartitionSpec() if spec is None else spec
500-
parsed_pspec = ParsedPartitionSpec.from_user_input(
501-
spec, "NamedSharding spec", allow_unconstrained_dims=True)
502-
_check_unique_resources(parsed_pspec, "NamedSharding spec", mesh)
503-
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
504-
return parsed_pspec
505420

421+
@cache(max_size=128, trace_context_in_key=False)
506422
def check_pspec(mesh, spec, _manual_axes=frozenset()):
507423
_check_unique_resources(spec, "NamedSharding spec", mesh)
508424
_check_mesh_resource_axis(mesh, spec, _manual_axes)
@@ -517,13 +433,10 @@ def __init__(self, message, mesh, pspec):
517433
def __str__(self):
518434
return f"{self.message}"
519435

520-
def _check_unique_resources(
521-
pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None,
522-
) -> None:
436+
def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
437+
) -> None:
523438
resource_counts: dict[MeshAxisName, int] = {}
524439
duplicate = False
525-
pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec)
526-
else pspec)
527440
for d in pspec:
528441
if d is PartitionSpec.UNCONSTRAINED or d is None:
529442
continue
@@ -542,10 +455,8 @@ def _check_unique_resources(
542455
f' for {mesh_lib.show_axes(multiple_uses)}'),
543456
mesh=mesh, pspec=pspec)
544457

545-
@cache(max_size=128, trace_context_in_key=False)
458+
546459
def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
547-
pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec)
548-
else pspec)
549460
for p in pspec:
550461
if p is PartitionSpec.UNCONSTRAINED or p is None:
551462
continue

jax/_src/pjit.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2905,41 +2905,3 @@ def get_unconstrained_dims(sharding: NamedSharding):
29052905
assert sharding.spec is not None
29062906
return {i for i, axes in enumerate(sharding.spec)
29072907
if axes is PartitionSpec.UNCONSTRAINED}
2908-
2909-
2910-
def get_op_sharding_from_executable(
2911-
executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
2912-
in_op_shardings: list[xc.OpSharding] = []
2913-
parameter_shardings_from_xla = executable.get_parameter_shardings()
2914-
if parameter_shardings_from_xla is not None:
2915-
in_op_shardings = parameter_shardings_from_xla
2916-
2917-
out_op_shardings: list[xc.OpSharding] = []
2918-
output_shardings_from_xla = executable.get_output_shardings()
2919-
if output_shardings_from_xla is not None:
2920-
out_op_shardings = output_shardings_from_xla
2921-
2922-
return in_op_shardings, out_op_shardings
2923-
2924-
2925-
def _get_ppspec_from_executable(
2926-
executable, mesh
2927-
) -> tuple[Sequence[PartitionSpec], Sequence[PartitionSpec]]:
2928-
input_op_shardings, output_op_sharding = get_op_sharding_from_executable(
2929-
executable
2930-
)
2931-
in_pspec: list[PartitionSpec] = []
2932-
for s in input_op_shardings:
2933-
in_pspec.extend(parse_flatten_op_sharding(s, mesh))
2934-
2935-
out_pspec: list[PartitionSpec] = []
2936-
for s in output_op_sharding:
2937-
out_pspec.extend(parse_flatten_op_sharding(s, mesh))
2938-
return in_pspec, out_pspec
2939-
2940-
2941-
def get_pspec_from_executable(
2942-
executable, mesh: pxla.Mesh
2943-
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
2944-
in_pspec, out_pspec = _get_ppspec_from_executable(executable, mesh)
2945-
return tuple(in_pspec), tuple(out_pspec)

jax/_src/sharding_impls.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@
3737
from jax._src.lib.mlir.dialects import sdy
3838
from jax._src.named_sharding import ( # noqa: F401
3939
SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO,
40-
ParsedPartitionSpec, _check_unique_resources, NamedSharding, UNSPECIFIED,
40+
_check_unique_resources, NamedSharding, UNSPECIFIED,
4141
ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping,
42-
array_mapping_to_axis_resources, get_single_pspec, preprocess,
43-
named_sharding_to_xla_hlo_sharding)
42+
array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding)
4443
from jax._src.op_shardings import (
4544
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
4645
from jax._src.partition_spec import PartitionSpec

0 commit comments

Comments
 (0)