
Commit 2b06f93

Merge pull request #25435 from gnecula:export_abs
PiperOrigin-RevId: 706712699
2 parents: 3b9a8f7 + afcb62e

6 files changed: +109 −43 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -28,6 +28,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
   * from {mod}`jax.numpy`: `round_`.
 
+* New Features
+  * {func}`jax.export.export` can be used for device-polymorphic export with
+    shardings constructed with {func}`jax.sharding.AbstractMesh`.
+    See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
+
 ## jax 0.4.37 (Dec 9, 2024)
 
 This is a patch release of jax 0.4.36. Only "jax" was released at this version.
```
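In short, the new feature lets you export a function against a mesh described only by axis names and sizes. A minimal sketch distilled from the `docs/export/export.md` change below (4 logical devices; `np` is NumPy):

```python
import jax
import numpy as np
from jax import export
from jax.sharding import AbstractMesh, NamedSharding
from jax.sharding import PartitionSpec as P

# An abstract 4-device mesh: axis names and sizes only, no concrete devices.
export_mesh = AbstractMesh((("a", 4),))

def f(x):
  return x.T

# Export from a ShapeDtypeStruct that carries the abstract sharding.
exp = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((32,), dtype=np.int32,
                         sharding=NamedSharding(export_mesh, P("a"))))
assert exp.nr_devices == 4  # exported for 4 logical devices
```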

docs/export/export.md

Lines changed: 38 additions & 13 deletions
````diff
@@ -240,7 +240,7 @@ present on the exporting machine:
 
 ```
 
-There is a safety check that will be raise an error when trying to compile
+There is a safety check that will raise an error when trying to compile
 an `Exported` object on a machine that does not have the accelerator
 for which the code was exported.
@@ -326,7 +326,7 @@ combinations of input shapes.
 
 See the {ref}`shape_poly` documentation.
 
-## Device polymorphic export
+## Device-polymorphic export
 
 An exported artifact may contain sharding annotations for inputs,
 outputs and for some intermediates, but these annotations do not refer
@@ -335,20 +335,28 @@ Instead, the sharding annotations refer to logical devices. This
 means that you can compile and run the exported artifacts on different
 physical devices that were used for exporting.
 
+The cleanest way to achieve a device-polymorphic export is to
+use shardings constructed with a `jax.sharding.AbstractMesh`,
+which contains only the mesh shape and axis names. But,
+you can achieve the same results if you use shardings
+constructed for a mesh with concrete devices, since the actual
+devices in the mesh are ignored for tracing and lowering:
+
 ```python
 >>> import jax
 >>> from jax import export
->>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
 >>> from jax.sharding import PartitionSpec as P
+>>>
+>>> # Use an AbstractMesh for exporting
+>>> export_mesh = AbstractMesh((("a", 4),))
 
->>> # Use the first 4 devices for exporting.
->>> export_devices = jax.local_devices()[:4]
->>> export_mesh = Mesh(export_devices, ("a",))
 >>> def f(x):
 ...   return x.T
 
->>> arg = jnp.arange(8 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))
 
 >>> # `exp` knows for how many devices it was exported.
 >>> exp.nr_devices
@@ -359,8 +367,20 @@ physical devices that were used for exporting.
 >>> exp.in_shardings_hlo
 ({devices=[4]<=[4]},)
 
+>>> # You can also use a concrete set of devices for exporting
+>>> concrete_devices = jax.local_devices()[:4]
+>>> concrete_mesh = Mesh(concrete_devices, ("a",))
+>>> exp2 = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
+...                          sharding=NamedSharding(concrete_mesh, P("a"))))
+
+>>> # You can expect the same results
+>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo
+
+>>> # When you call an Exported, you must use a concrete set of devices
+>>> arg = jnp.arange(8 * 4)
 >>> res1 = exp.call(jax.device_put(arg,
-...                                NamedSharding(export_mesh, P("a"))))
+...                                NamedSharding(concrete_mesh, P("a"))))
 
 >>> # Check out the first 2 shards of the result
 >>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
@@ -397,9 +417,11 @@ of devices than it was exported for:
 >>> def f(x):
 ...   return x.T
 
->>> arg = jnp.arange(4 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))
 
+>>> arg = jnp.arange(4 * len(export_devices))
 >>> exp.call(arg)  # doctest: +IGNORE_EXCEPTION_DETAIL
 Traceback (most recent call last):
 ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
@@ -420,13 +442,16 @@ artifacts using a new mesh constructed at the call site:
 >>> def f(x):
 ...   return x.T
 
->>> arg = jnp.arange(4 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))
 
 >>> # Prepare the mesh for calling `exp`.
 >>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))
 
 >>> # Shard the arg according to what `exp` expects.
+>>> arg = jnp.arange(4 * len(export_devices))
 >>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
 >>> res = exp.call(sharded_arg)
````
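In practice, a device-polymorphic artifact is usually serialized on one machine and invoked on another whose devices did not exist at export time. A short sketch using the `jax.export` serialization API (variable names continue the doctest above; `jnp`/`np` as in that example):

```python
# On the exporting machine: serialize the device-polymorphic artifact.
serialized: bytearray = export.serialize(exp)

# On the consuming machine: rehydrate and call with concrete devices.
rehydrated = export.deserialize(serialized)
assert rehydrated.nr_devices == 4

run_mesh = Mesh(jax.local_devices()[:4], ("a",))
arg = jnp.arange(32, dtype=np.int32)
res = rehydrated.call(
    jax.device_put(arg, NamedSharding(run_mesh, P("a"))))
```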

jax/_src/export/_export.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -633,7 +633,7 @@ def _export_lowered(
     jaxpr: core.ClosedJaxpr,
     fun_name: str,
     disabled_checks: Sequence[DisabledSafetyCheck] = (),
-    _device_assignment_for_internal_jax2tf_use_only = None,
+    _device_assignment_for_internal_jax2tf_use_only=None,
 ) -> Exported:
   version = config.jax_export_calling_convention_version.value
   if (version < minimum_supported_calling_convention_version or
@@ -698,7 +698,7 @@ def _export_lowered(
   ordered_effects = tuple(lowering.compile_args["ordered_effects"])
   unordered_effects = tuple(lowering.compile_args["unordered_effects"])
 
-  nr_devices = len(lowering.compile_args["device_assignment"])
+  nr_devices = lowering.compile_args["num_devices"]
   def export_sharding(s: LoweringSharding,
                       aval: core.ShapedArray) -> HloSharding | None:
     if isinstance(s, sharding_impls.UnspecifiedValue):
@@ -971,7 +971,8 @@ def _check_lowering(lowering) -> None:
      "keepalive", "host_callbacks", "pmap_nreps", "committed",
      "device_assignment", "jaxpr_debug_info", "shape_poly_state",
      "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
-      "pgle_profiler", "intermediate_shardings", "context_mesh"}
+      "pgle_profiler", "intermediate_shardings", "context_mesh",
+      "num_devices"}
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
```
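The motivation for the `nr_devices` change: when lowering uses only an `AbstractMesh`, the lowering's device assignment is `None` (see the `pxla.py` change below), so taking its `len` would fail; the new `num_devices` compile arg is well-defined in both the concrete and abstract cases. An illustrative toy, not the real lowering object:

```python
# Hypothetical stand-in for lowering.compile_args in the abstract-mesh case.
compile_args = {"device_assignment": None, "num_devices": 4}

# Old code: len(compile_args["device_assignment"]) -> TypeError when None.
# New code: always well-defined.
nr_devices = compile_args["num_devices"]
assert nr_devices == 4
```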

jax/_src/interpreters/pxla.py

Lines changed: 32 additions & 25 deletions
```diff
@@ -1999,7 +1999,7 @@ def jaxpr_transfer_mem_kinds(
   return out
 
 
-def are_all_shardings_default_mem_kind(da_object, shardings):
+def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
   if da_object is None:
     return True
   try:
@@ -2084,38 +2084,32 @@ def write(var, val):
 
 
 def _get_num_devices(
-    shardings, device_assignment, lowering_platforms, prim_requires_devices
-) -> tuple[int, tuple[xc.Device, ...] | None]:
-  ext_abstract_mesh, concrete_sharding = None, False
+    shardings, device_assignment
+) -> tuple[int, tuple[xc.Device, ...] | None]:
+  """Number of lowering devices, and the device_assignment to use.
+
+  If all the specified shardings have an abstract mesh, then we are compiling
+  with abstract devices, and the returned device_assignment is None.
+  """
+  abstract_mesh, any_concrete_sharding = None, False
   for s in shardings:
     if isinstance(s, UnspecifiedValue):
       continue
     elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
-      if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
+      if abstract_mesh is not None and abstract_mesh != s.mesh:
         raise ValueError("AbstractMesh should be the same across all "
-                         f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
-      ext_abstract_mesh = s.mesh
+                         f"shardings. Got {abstract_mesh} and {s.mesh}")
+      abstract_mesh = s.mesh
     else:
-      concrete_sharding = True
-  if (concrete_sharding and ext_abstract_mesh is not None and
-      len(device_assignment) != ext_abstract_mesh.size):
+      any_concrete_sharding = True
+  if (any_concrete_sharding and abstract_mesh is not None and
+      len(device_assignment) != abstract_mesh.size):
     raise ValueError(
-        f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
+        f"AbstractMesh size: {abstract_mesh.size} does not match the"
         f" device assignment size: {len(device_assignment)}")
-  if concrete_sharding:
+  if any_concrete_sharding or abstract_mesh is None:
     return len(device_assignment), device_assignment
-  if ext_abstract_mesh is None:
-    return len(device_assignment), device_assignment
-  if lowering_platforms is None:
-    raise ValueError(
-        "Passing lowering_platforms via"
-        " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
-        " only AbstractMesh exists in a jitted computation.")
-  if prim_requires_devices:
-    raise ValueError(
-        "AbstractMesh cannot be used when jaxpr contains primitives that"
-        " require devices to be present during lowering.")
-  return ext_abstract_mesh.size, None
+  return abstract_mesh.size, None
 
 
 MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
@@ -2269,7 +2263,17 @@ def lower_sharding_computation(
   num_devices, device_assignment = _get_num_devices(  # type: ignore
       it.chain(unique_in_shardings, unique_out_shardings,
                unique_intermediate_shardings),
-      device_assignment, lowering_platforms, prim_requires_devices)
+      device_assignment)
+  if device_assignment is None:
+    if lowering_platforms is None:
+      raise ValueError(
+          "Passing lowering_platforms via jax.export or "
+          " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
+          " only AbstractMesh exists in a jitted computation.")
+    if prim_requires_devices:
+      raise ValueError(
+          "AbstractMesh cannot be used when jaxpr contains primitives that"
+          " require devices to be present during lowering.")
 
   committed = bool(
       devices_from_context
@@ -2349,6 +2353,7 @@ def lower_sharding_computation(
       mut=mut,
       backend=backend,
       device_assignment=da_object,
+      num_devices=num_devices,
      committed=committed,
      in_layouts=in_layouts,
      out_layouts=out_layouts,
@@ -2874,6 +2879,7 @@ def from_hlo(name: str,
               in_layouts: MaybeLayout,
               out_layouts: MaybeLayout,
               compiler_options_kvs: tuple[tuple[str, Any], ...],
+               num_devices: int,
               pmap_nreps: int = 1,
               mut: MutationData | None = None,
               shape_poly_state: mlir.ShapePolyLoweringState | None = None,
@@ -2883,6 +2889,7 @@ def from_hlo(name: str,
               intermediate_shardings: Sequence[JSharding] | None = None,
               context_mesh: Mesh | None = None,
               ) -> MeshExecutable:
+    del num_devices  # For compilation, we have an actual device_assignment
     if (device_assignment is None or
         any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
             for s in it.chain(in_shardings, out_shardings))):
```
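To summarize the refactor: `_get_num_devices` now only decides the device count and whether a concrete `device_assignment` exists, while the `lowering_platforms` and `prim_requires_devices` checks move to the caller, where a `None` assignment signals the abstract-only case. A simplified, self-contained model of the decision (with `UnspecifiedValue` replaced by `None` for illustration):

```python
from jax.sharding import AbstractMesh, NamedSharding

def get_num_devices_sketch(shardings, device_assignment):
  """Simplified model of _get_num_devices after this change.

  Returns (num_devices, device_assignment); the assignment is None
  exactly when every specified sharding shares one AbstractMesh.
  """
  abstract_mesh, any_concrete_sharding = None, False
  for s in shardings:
    if s is None:  # stand-in for UnspecifiedValue
      continue
    if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
      if abstract_mesh is not None and abstract_mesh != s.mesh:
        raise ValueError("AbstractMesh should be the same across all shardings")
      abstract_mesh = s.mesh
    else:
      any_concrete_sharding = True
  if (any_concrete_sharding and abstract_mesh is not None
      and len(device_assignment) != abstract_mesh.size):
    raise ValueError("AbstractMesh size does not match device assignment size")
  if any_concrete_sharding or abstract_mesh is None:
    return len(device_assignment), device_assignment
  return abstract_mesh.size, None  # abstract-only: no concrete devices yet
```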

jax/_src/mesh.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -368,7 +368,7 @@ class AbstractMesh:
   It does not contain concrete devices compared to `jax.sharding.Mesh`. You
   should use this as an input to the sharding passed to with_sharding_constraint
   and mesh passed to shard_map to avoid tracing and lowering cache misses when
-  your mesh shape and names stay the same but the devices change.
+  your mesh shape and axis names stay the same but the devices change.
   See the description of https://github.com/jax-ml/jax/pull/23022 for more
   details.
   """
```

tests/export_test.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -1156,7 +1156,35 @@ def test_input_shardings_unused_args(self):
     self.assertEqual(res.addressable_shards[0].device, run_devices[0])
     self.assertEqual(res.addressable_shards[1].device, run_devices[1])
 
-  def test_call_with_different_no_of_devices(self):
+  def test_export_abstract_mesh(self):
+    if jax.local_device_count() < 2:
+      self.skipTest("Need at least 2 devices")
+
+    abs_mesh = jax.sharding.AbstractMesh((("x", 2),))
+    input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None))
+    output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x"))
+    @jax.jit
+    def f(a):
+      b = a @ a.T
+      return jax.lax.with_sharding_constraint(b, output_sharding)
+
+    exp = get_exported(f)(
+        jax.ShapeDtypeStruct((16, 16), dtype=np.float32,
+                             sharding=input_sharding))
+    # Call the Exported with a concrete Mesh
+    devices = jax.local_devices()[:2]
+    run_mesh = Mesh(devices, ("x",))
+    a_sharding = jax.sharding.NamedSharding(run_mesh, P("x", None))
+    a = jnp.arange(16 * 16, dtype=np.float32).reshape((16, 16))
+    a = jax.device_put(a, a_sharding)
+
+    res = exp.call(a)
+    self.assertAllClose(res, f(a))
+    self.assertLen(res.addressable_shards, 2)
+    self.assertEqual(res.addressable_shards[0].index, (slice(None), slice(0, 8)))
+    self.assertEqual(res.addressable_shards[1].index, (slice(None), slice(8, 16)))
+
+  def test_call_single_device_export_with_different_no_of_devices(self):
     if jax.local_device_count() < 2:
       self.skipTest("Need at least 2 devices")
 
```