
Commit d0f63da

yashk2810 authored and Google-ML-Automation committed
Allow tracing and lowering (with lowering_platforms specified) to work with an AbstractMesh. Such a computation cannot be compiled.

This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.

PiperOrigin-RevId: 705764178
1 parent 0e7f218 commit d0f63da

File tree: 3 files changed, +128 −31 lines

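For context, a minimal sketch of what this change enables, adapted from the new test in tests/pjit_test.py below. It substitutes the public jax.make_mesh for the test-only jtu.create_mesh helper and assumes at least two local devices, which are used only to construct the AbstractMesh; the lowering itself touches no devices:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))   # concrete mesh, used only to derive
abstract_mesh = mesh.abstract_mesh   # its device-free AbstractMesh
abstract_sds = jax.ShapeDtypeStruct(
    (8, 2), jnp.float32, sharding=NamedSharding(abstract_mesh, P('x')))

@jax.jit
def f(x):
  return x * 2

# Tracing and lowering now work, provided lowering_platforms is specified:
lowered = f.trace(abstract_sds).lower(lowering_platforms=('tpu',))
assert 'num_partitions = 2' in lowered.as_text()

# Compiling such a computation is rejected, since there are no concrete devices:
# lowered.compile()  # RuntimeError: A jitted computation cannot contain AbstractMesh ...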

jax/_src/interpreters/pxla.py

Lines changed: 79 additions & 30 deletions
@@ -41,7 +41,6 @@
 from jax._src import dtypes
 from jax._src import effects
 from jax._src import linear_util as lu
-from jax._src import mesh as mesh_lib
 from jax._src import op_shardings
 from jax._src import sharding_specs
 from jax._src import profiler
@@ -65,6 +64,7 @@
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
 from jax._src.sharding import Sharding as JSharding
+from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
     UnspecifiedValue, get_array_mapping as _get_array_mapping,
@@ -98,7 +98,6 @@ class WeakRefList(list):
 Replicated = sharding_specs.Replicated

 AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
-Mesh = mesh_lib.Mesh
 MeshAxisName = sharding_impls.MeshAxisName
 MeshDimAssignment = Union[ShardedAxis, Replicated]
 ShardingSpec = sharding_specs.ShardingSpec
@@ -1723,20 +1722,19 @@ def _get_and_check_device_assignment(
     devices: Sequence[xc.Device] | None,
 ) -> tuple[xc.Client, tuple[xc.Device, ...]]:
   first_sharding_info = None
-  if devices is None:
-    devices = ()
-  else:
-    devices = tuple(devices)
+  devices = () if devices is None else tuple(devices)

-  for i, s_type, source_info in shardings:
-    if isinstance(i, UnspecifiedValue):
+  for sh, s_type, source_info in shardings:
+    if isinstance(sh, UnspecifiedValue):
+      continue
+    if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
       continue
-
     if first_sharding_info is None:
       first_sharding_info = (
-          (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
-          else (i._device_assignment, s_type, source_info))
-    arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
+          (sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO)
+          else (sh._device_assignment, s_type, source_info))
+    arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO)
+                             else sh._device_assignment)
     if not devices:
       if first_sharding_info[0] != arr_device_assignment:
         raise DeviceAssignmentMismatchError([
@@ -1837,7 +1835,8 @@ class SemanticallyEqualShardings:
   def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
                avals: tuple[core.AbstractValue]):
     gspmd_shardings = [
-        s if isinstance(s, (UnspecifiedValue, AUTO))
+        s if (isinstance(s, (UnspecifiedValue, AUTO)) or
+              (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
         else to_gspmd_sharding(s, a.ndim)  # pytype: disable=attribute-error
         for s, a in zip(shardings, avals)]
     self._gspmd_shardings = gspmd_shardings
@@ -1895,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            abstract_mesh: mesh_lib.AbstractMesh | None):
+                            abstract_mesh: AbstractMesh | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -2082,6 +2081,40 @@ def write(var, val):
   return tuple(safe_map(read, jaxpr.outvars))


+def _get_num_devices(shardings, device_assignment, lowering_platforms,
+                     prim_requires_devices) -> int:
+  ext_abstract_mesh, concrete_sharding = None, False
+  for s in shardings:
+    if isinstance(s, UnspecifiedValue):
+      continue
+    elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
+      if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
+        raise ValueError("AbstractMesh should be the same across all "
+                         f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
+      ext_abstract_mesh = s.mesh
+    else:
+      concrete_sharding = True
+  if (concrete_sharding and ext_abstract_mesh is not None and
+      len(device_assignment) != ext_abstract_mesh.size):
+    raise ValueError(
+        f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
+        f" device assignment size: {len(device_assignment)}")
+  if concrete_sharding:
+    return len(device_assignment)
+  if ext_abstract_mesh is None:
+    return len(device_assignment)
+  if lowering_platforms is None:
+    raise ValueError(
+        "Passing lowering_platforms via"
+        " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
+        " only AbstractMesh exists in a jitted computation.")
+  if prim_requires_devices:
+    raise ValueError(
+        "AbstractMesh cannot be used when jaxpr contains primitives that"
+        " require devices to be present during lowering.")
+  return ext_abstract_mesh.size
+
+
 MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]


@@ -2126,7 +2159,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):

   @lru_cache(maxsize=128)
   def _abstract_to_concrete_mesh(abstract_mesh):
-    return mesh_lib.Mesh(
+    return Mesh(
         np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
         axis_types=abstract_mesh.axis_types)

@@ -2153,7 +2186,7 @@ def lower_sharding_computation(
     donated_invars: Sequence[bool],
     *,
     keep_unused: bool,
-    context_mesh: mesh_lib.Mesh | None,
+    context_mesh: Mesh | None,
     compiler_options_kvs: tuple[tuple[str, Any], ...],
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
@@ -2211,6 +2244,7 @@ def lower_sharding_computation(
           ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
            for js, source_info in unique_intermediate_shardings)),
       devices_from_context)
+  unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]

   if config.sharding_in_types.value:
     out_shardings = _concretize_abstract_shardings(
@@ -2221,21 +2255,31 @@ def lower_sharding_computation(
   platforms = lowering_platforms or (
       getattr(backend, "_raw_platform", backend.platform),)

+  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
+
+  # TODO(yashkatariya): All device specific logic should go in compilation
+  # but this requires a big refactor. The current `_get_num_devices` logic
+  # is good enough to lower with AbstractMesh but cannot be compiled. Once
+  # I refactor, this will also work well with mesh being provided at
+  # compile time.
+  num_devices = _get_num_devices(
+      it.chain(unique_in_shardings, unique_out_shardings,
+               unique_intermediate_shardings),
+      device_assignment, lowering_platforms, prim_requires_devices)
+
   committed = bool(
-      devices_from_context or
-      len(device_assignment) > 1 or
-      any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
-      any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
-      any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
+      devices_from_context
+      or num_devices > 1
+      or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
+          unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))

   da_object = _create_da_object(tuple(device_assignment))

   transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
       it.chain(unique_in_shardings, unique_out_shardings,
-               [js for js, _ in unique_intermediate_shardings],
-               transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types
+               unique_intermediate_shardings, transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types

   if all_default_mem_kind:
     propagated_out_mem_kinds = (None,) * len(global_out_avals)
@@ -2244,12 +2288,11 @@ def lower_sharding_computation(
       closed_jaxpr, in_shardings)

   # 2. Build up the HLO
-  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)

   abstract_mesh = None
   if prim_requires_devices:
     for sharding in it.chain(unique_in_shardings, unique_out_shardings,
-                             [js for js, _ in unique_intermediate_shardings]):
+                             unique_intermediate_shardings):
       if isinstance(sharding, NamedSharding):
         if (abstract_mesh is not None and
             abstract_mesh != sharding.mesh.abstract_mesh):
@@ -2267,7 +2310,7 @@ def lower_sharding_computation(
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-      semantic_out_shardings, in_layouts, out_layouts, len(da_object),
+      semantic_out_shardings, in_layouts, out_layouts, num_devices,
       tuple(da_object) if prim_requires_devices else None, donated_invars,
       name_stack, all_default_mem_kind, inout_aliases,
       propagated_out_mem_kinds, platforms,
@@ -2310,7 +2353,7 @@ def lower_sharding_computation(
       all_default_mem_kind=all_default_mem_kind,
       all_args_info=all_args_info,
       pgle_profiler=pgle_profiler,
-      intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
+      intermediate_shardings=unique_intermediate_shardings,
       context_mesh=context_mesh)


@@ -2480,7 +2523,7 @@ def _register_out_sharding_handler(

 def _gspmd_to_named_sharding(
     out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
-  assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
+  assert isinstance(orig_in_s.mesh, Mesh)
   return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

 _register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
@@ -2532,7 +2575,7 @@ def _get_out_sharding_from_orig_sharding(

 def maybe_recover_user_shardings(
     old_shardings, new_shardings, old_avals, new_avals,
-    intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
+    intermediate_shardings=None, context_mesh: Mesh | None = None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
     return new_shardings

@@ -2832,8 +2875,14 @@ def from_hlo(name: str,
                all_args_info: AllArgsInfo | None = None,
                pgle_profiler: profiler.PGLEProfiler | None = None,
                intermediate_shardings: Sequence[JSharding] | None = None,
-               context_mesh: mesh_lib.Mesh | None = None
+               context_mesh: Mesh | None = None,
                ) -> MeshExecutable:
+    if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
+           for s in it.chain(in_shardings, out_shardings)):
+      raise RuntimeError(
+          "A jitted computation cannot contain AbstractMesh in in_shardings and"
+          " out_shardings during compilation. You can use `jax.export` to"
+          " lower with an AbstractMesh and later compile with concrete devices.")
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
     if isinstance(device_assignment, xc.DeviceList):

jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
(whitespace-only change: the continuation line's indentation is realigned)

@@ -498,7 +498,7 @@ def trace(*args, **kwargs) -> stages.Traced:
     donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
     args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
-                            pgle_profiler=None)
+                             pgle_profiler=None)
     return stages.Traced(
         p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
         lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)

tests/pjit_test.py

Lines changed: 48 additions & 0 deletions
@@ -4631,6 +4631,54 @@ def f(x):
     ins, _ = f.lower(np.arange(8)).compile().input_shardings
     self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))

+  def test_abstract_mesh_lower(self):
+    mesh = jtu.create_mesh((2,), 'x')
+    mesh2 = jtu.create_mesh((1,), 'x')
+
+    abstract_sds = jax.ShapeDtypeStruct(
+        (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x')))
+    abstract_sds2 = jax.ShapeDtypeStruct(
+        (8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('x')))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    lowered = f.trace(abstract_sds).lower(lowering_platforms=('tpu',))
+    self.assertIn('num_partitions = 2', lowered.as_text())
+
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered.compile()
+
+    @jax.jit
+    def g(x, y):
+      return x, y
+
+    concrete_s = NamedSharding(mesh, P('x'))
+    concrete_sds = jax.ShapeDtypeStruct((8,), jnp.float32, sharding=concrete_s)
+    with self.assertRaisesRegex(
+        ValueError,
+        'AbstractMesh size: 1 does not match the device assignment size: 2'):
+      g.lower(abstract_sds2, concrete_sds)
+
+    with self.assertRaisesRegex(
+        ValueError, "Passing lowering_platforms.*is required"):
+      g.lower(abstract_sds, np.arange(8))
+
+    lowered2 = g.trace(abstract_sds, np.arange(8)).lower(
+        lowering_platforms=('tpu',))
+    self.assertIn('num_partitions = 2', lowered2.as_text())
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered2.compile()
+
+    lowered3 = g.lower(abstract_sds, concrete_sds)
+    self.assertIn('num_partitions = 2', lowered3.as_text())
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered3.compile()
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
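Finally, the jax.export use case the commit message calls out, as a hedged sketch: it assumes the public jax.export API, reuses f and abstract_sds from the sketch near the top, and whether a given computation exports cleanly with an AbstractMesh depends on the rest of the export pipeline.

from jax import export

# Cross-platform export: lower for TPU on a machine with no TPUs attached.
exp = export.export(f, platforms=('tpu',))(abstract_sds)
blob = exp.serialize()  # ship the bytes; compile where real devices exist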
