Skip to content

Commit 46fb400

Browse files
yashk2810 authored and charleshofer committed
Remove _manual_axes from NamedSharding since we can now track the manual axes on the mesh.
PiperOrigin-RevId: 748534841
1 parent f7e3fe6 commit 46fb400

File tree

8 files changed

+39
-54
lines changed

8 files changed

+39
-54
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ pytype_strict_library(
609609
":dtypes",
610610
":effects",
611611
":layout",
612+
":mesh",
612613
":op_shardings",
613614
":partial_eval",
614615
":partition_spec",

jax/_src/interpreters/mlir.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from jax._src.interpreters import xla
5050
from jax._src.layout import AutoLayout, DeviceLocalLayout
5151
from jax._src.partition_spec import PartitionSpec
52+
from jax._src.mesh import AxisType
5253
from jax._src.sharding import Sharding as JSharding
5354
from jax._src.sharding_impls import (AUTO, NamedSharding,
5455
modify_sdy_sharding_wrt_axis_types,
@@ -1017,18 +1018,29 @@ class LoweringResult(NamedTuple):
10171018

10181019

10191020
def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
1020-
mesh = axis_ctx.mesh
1021+
mesh = axis_ctx.mesh.abstract_mesh
1022+
sharding_mesh = sharding.mesh.abstract_mesh
10211023
if (isinstance(sharding, sharding_impls.NamedSharding) and
1022-
sharding.mesh.shape == mesh.shape):
1023-
return sharding_impls.NamedSharding(
1024-
sharding.mesh, sharding.spec, memory_kind=sharding.memory_kind,
1025-
_manual_axes=axis_ctx.manual_axes)
1024+
sharding_mesh.shape == mesh.shape):
1025+
out_mesh, spec = sharding_mesh, sharding.spec
10261026
else:
1027-
spec = sharding_impls.parse_flatten_op_sharding(
1027+
out_mesh, spec = mesh, sharding_impls.parse_flatten_op_sharding(
10281028
sharding._to_xla_hlo_sharding(ndim), mesh)[0]
1029-
return sharding_impls.NamedSharding(
1030-
mesh, spec, memory_kind=sharding.memory_kind,
1031-
_manual_axes=axis_ctx.manual_axes)
1029+
1030+
out_mesh = out_mesh.update_axis_types(
1031+
{a: AxisType.Manual for a in axis_ctx.manual_axes})
1032+
out = sharding_impls.NamedSharding(out_mesh, spec,
1033+
memory_kind=sharding.memory_kind)
1034+
manual_axes = out.mesh.manual_axes
1035+
if any(p in manual_axes for s in out.spec
1036+
if s is not None and s is not PartitionSpec.UNCONSTRAINED
1037+
for p in (s if isinstance(s, tuple) else (s,))):
1038+
raise ValueError(
1039+
f'pspec {out.spec} contains a manual axes {manual_axes} of mesh'
1040+
f' which is not allowed. If you are using a'
1041+
' with_sharding_constraint under a shard_map, only use the'
1042+
' mesh axis in PartitionSpec which are not manual.')
1043+
return out
10321044

10331045

10341046
def _to_physical_op_sharding(

jax/_src/named_sharding.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,17 @@ class NamedSharding(JSharding.Sharding):
112112
mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
113113
spec: PartitionSpec
114114
_memory_kind: str | None
115-
_manual_axes: frozenset[MeshAxisName]
116115
_logical_device_ids: tuple[int, ...] | None
117116

118117
@use_cpp_method()
119118
def __init__(
120119
self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *,
121-
memory_kind: str | None = None, _manual_axes=frozenset(),
122-
_logical_device_ids=None):
120+
memory_kind: str | None = None, _logical_device_ids=None):
123121
self.mesh = mesh
124122
self.spec = spec
125123
self._memory_kind = memory_kind
126-
self._manual_axes = _manual_axes
127124
self._logical_device_ids = _logical_device_ids
128-
check_pspec(self.mesh, self.spec, self._manual_axes)
125+
check_pspec(self.mesh, self.spec)
129126

130127
def __repr__(self):
131128
mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
@@ -137,7 +134,6 @@ def __repr__(self):
137134
def __reduce__(self):
138135
return (type(self), (self.mesh, self.spec),
139136
{'memory_kind': self.memory_kind,
140-
'_manual_axes': self._manual_axes,
141137
'_logical_device_ids': self._logical_device_ids})
142138

143139
@property
@@ -147,8 +143,7 @@ def memory_kind(self) -> str | None:
147143
def __hash__(self):
148144
if not hasattr(self, '_hash'):
149145
self._hash = hash(
150-
(self.mesh, self.memory_kind, self.spec, self._manual_axes,
151-
self._logical_device_ids))
146+
(self.mesh, self.memory_kind, self.spec, self._logical_device_ids))
152147
return self._hash
153148

154149
def __eq__(self, other):
@@ -158,7 +153,6 @@ def __eq__(self, other):
158153
return True
159154
if (self.spec != other.spec
160155
or self.memory_kind != other.memory_kind
161-
or self._manual_axes != other._manual_axes
162156
or self._logical_device_ids != other._logical_device_ids):
163157
return False
164158
return self.mesh is other.mesh or self.mesh == other.mesh
@@ -333,9 +327,7 @@ def named_sharding_to_xla_hlo_sharding(
333327
mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}
334328

335329
special_axes = {}
336-
mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
337-
if t == mesh_lib.AxisType.Manual}
338-
manual_axes = self._manual_axes.union(mesh_manual_axes)
330+
manual_axes = frozenset(self.mesh.manual_axes)
339331
if manual_axes:
340332
axis_names = self.mesh.axis_names
341333
for manual_axis in manual_axes:
@@ -420,7 +412,7 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
420412
@cache(max_size=128, trace_context_in_key=False)
421413
def check_pspec(mesh, spec, _manual_axes=frozenset()):
422414
_check_unique_resources(spec, "NamedSharding spec", mesh)
423-
_check_mesh_resource_axis(mesh, spec, _manual_axes)
415+
_check_mesh_resource_axis(mesh, spec)
424416

425417
class DuplicateSpecError(Exception):
426418
def __init__(self, message, mesh, pspec):
@@ -455,7 +447,7 @@ def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
455447
mesh=mesh, pspec=pspec)
456448

457449

458-
def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
450+
def _check_mesh_resource_axis(mesh, pspec):
459451
for p in pspec:
460452
if p is PartitionSpec.UNCONSTRAINED or p is None:
461453
continue
@@ -465,10 +457,6 @@ def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
465457
raise ValueError(
466458
f"Resource axis: {r} of {pspec} "
467459
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
468-
if r in _manual_axes:
469-
raise ValueError(
470-
f"Axis: {r} of {pspec} "
471-
f"is also found in manual_axes: {_manual_axes}.") from None
472460
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
473461
raise ValueError(
474462
'AxisTypes should be the same in a tuple subset of PartitionSpec:'

jax/experimental/shard_map.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -785,13 +785,12 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
785785

786786
def _make_scoped_manual_sharding(ctx, mesh, axes):
787787
axis_ctx = ctx.module_context.axis_context
788+
mesh = mesh.abstract_mesh
788789
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
789-
manual_axes = axis_ctx.manual_axes
790-
else:
791-
manual_axes = frozenset({})
790+
mesh = mesh.update_axis_types(
791+
{a: AxisType.Manual for a in axis_ctx.manual_axes})
792792
return NamedSharding(
793-
mesh, sharding_impls.array_mapping_to_axis_resources(axes), # pytype: disable=wrong-arg-types
794-
_manual_axes=manual_axes)
793+
mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore
795794

796795
def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
797796
aval_in, aval_out, x):

jaxlib/xla/sharding.cc

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ bool ShardingEqual(nb::handle a, nb::handle b) {
167167
a_named_sharding->spec().equal(b_named_sharding->spec()) &&
168168
a_named_sharding->memory_kind().equal(
169169
b_named_sharding->memory_kind()) &&
170-
a_named_sharding->manual_axes().equal(
171-
b_named_sharding->manual_axes()) &&
172170
a_named_sharding->logical_device_ids().equal(
173171
b_named_sharding->logical_device_ids());
174172
}
@@ -204,15 +202,14 @@ static const std::array<absl::string_view, 3> valid_memory_kinds = {
204202
};
205203

206204
NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
207-
nb::object memory_kind, nb::object manual_axes,
205+
nb::object memory_kind,
208206
nb::object logical_device_ids)
209207
: Sharding(/*num_devices=*/[&mesh]() {
210208
return nb::cast<int>(mesh.attr("size"));
211209
}()),
212210
mesh_(std::move(mesh)),
213211
spec_(std::move(spec)),
214212
memory_kind_(std::move(memory_kind)),
215-
manual_axes_(std::move(manual_axes)),
216213
logical_device_ids_(std::move(logical_device_ids)) {
217214
if (spec_.is_none()) {
218215
throw nb::type_error(
@@ -261,7 +258,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
261258
}
262259
return output;
263260
}();
264-
(*check_pspec)(mesh_, spec_, manual_axes_);
261+
(*check_pspec)(mesh_, spec_);
265262
}
266263

267264
/*static*/ PyObject* NamedSharding::type_ = nullptr;
@@ -352,16 +349,13 @@ void RegisterSharding(nb::module_& m) {
352349
nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
353350

354351
nb::class_<NamedSharding, Sharding>(m, "NamedSharding", nb::dynamic_attr())
355-
.def(nb::init<nb::object, nb::object, nb::object, nb::object,
356-
nb::object>(),
352+
.def(nb::init<nb::object, nb::object, nb::object, nb::object>(),
357353
nb::arg("mesh"), nb::arg("spec").none(),
358354
nb::arg("memory_kind").none() = nb::none(),
359-
nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)),
360355
nb::arg("_logical_device_ids").none() = nb::none())
361356
.def_prop_ro("mesh", &NamedSharding::mesh)
362357
.def_prop_ro("spec", &NamedSharding::spec)
363358
.def_prop_ro("_memory_kind", &NamedSharding::memory_kind)
364-
.def_prop_ro("_manual_axes", &NamedSharding::manual_axes)
365359
.def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids)
366360
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
367361
return xla::ValueOrThrow(s.internal_device_list());

jaxlib/xla/sharding.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,12 @@ bool ShardingEqual(nanobind::handle a, nanobind::handle b);
7575
class NamedSharding : public Sharding {
7676
public:
7777
NamedSharding(nanobind::object mesh, nanobind::object spec,
78-
nanobind::object memory_kind, nanobind::object manual_axes,
78+
nanobind::object memory_kind,
7979
nanobind::object logical_device_ids);
8080

8181
const nanobind::object& mesh() const { return mesh_; }
8282
const nanobind::object& spec() const { return spec_; }
8383
const nanobind::object& memory_kind() const { return memory_kind_; }
84-
const nanobind::object& manual_axes() const { return manual_axes_; }
8584
const nanobind::object& logical_device_ids() const {
8685
return logical_device_ids_;
8786
}
@@ -102,7 +101,6 @@ class NamedSharding : public Sharding {
102101
nanobind::object mesh_;
103102
nanobind::object spec_;
104103
nanobind::object memory_kind_;
105-
nanobind::object manual_axes_;
106104
nanobind::object logical_device_ids_;
107105
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
108106
static PyObject* type_;

jaxlib/xla/xla_extension/__init__.pyi

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,14 +926,12 @@ class NamedSharding(Sharding):
926926
spec: Any,
927927
*,
928928
memory_kind: Optional[str] = None,
929-
_manual_axes: frozenset[Any] = frozenset(),
930929
_logical_device_ids: tuple[int, ...] | None = None,
931930
): ...
932931
mesh: Any
933932
spec: Any
934933
_memory_kind: Optional[str]
935934
_internal_device_list: DeviceList
936-
_manual_axes: frozenset[Any]
937935
_logical_device_ids: tuple[int, ...] | None
938936

939937
class SingleDeviceSharding(Sharding):

tests/shard_map_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,17 +2234,12 @@ def g(x):
22342234
return x * x
22352235

22362236
def h(x):
2237-
return shard_map(g, mesh,
2238-
in_specs=P(None, 'j'),
2239-
out_specs=P(None, 'j'))(x)
2237+
return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)
22402238

22412239
@jax.jit
22422240
def f(x):
2243-
return shard_map(h, mesh,
2244-
in_specs=P('i', None),
2245-
out_specs=P('i', None),
2246-
check_rep=False,
2247-
auto=frozenset({'j'}))(x)
2241+
return shard_map(h, mesh, in_specs=P('i', None), out_specs=P('i', None),
2242+
check_rep=False, auto=frozenset({'j'}))(x)
22482243

22492244
v = jnp.arange(32.).reshape(4, 8)
22502245
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))

0 commit comments

Comments
 (0)