Skip to content

Commit b33aca6

Browse files
committed
[export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export module, which will be deprecated. The following APIs are introduced: ``` from jax import export def f(...): ... ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs) blob: bytearray = ex.serialize() rehydrated: export.Export = export.deserialize(blob) def caller(...): ... rehydrated.call(*args, **kwargs) ``` Module documentation will follow shortly. There are no changes for now in the jax.experimental.export APIs. Most of the changes in this PR are in tests due to some differences in the new jax.export APIs compared to jax.experimental.export: * Instead of `jax.experimental.export.call(exp)` we now write `exp.call` * The `jax.experimental.export.export` allowed the function argument to be any Python callable and it would wrap it with a `jax.jit`. This is not supported anymore by export, and instead the user must use `jax.jit`.
1 parent 14d87d3 commit b33aca6

File tree

15 files changed

+342
-159
lines changed

15 files changed

+342
-159
lines changed

benchmarks/shape_poly_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import jax
1919
from jax import core
2020
from jax._src.numpy import lax_numpy
21-
from jax.experimental import export
21+
from jax import export
2222

2323
jax.config.parse_flags_with_absl()
2424

jax/_src/export/_export.py

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,10 @@ def in_shardings_jax(
333333
`jax.device_put`.
334334
335335
Example usage:
336-
>>> from jax.experimental import export
336+
>>> from jax import export
337337
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
338338
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
339-
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
339+
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
340340
... )(np.arange(jax.device_count()))
341341
>>> exp.in_shardings_hlo
342342
({devices=[8]<=[8]},)
@@ -347,7 +347,7 @@ def in_shardings_jax(
347347
# Put the args and kwargs on the appropriate devices
348348
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
349349
... exp.in_shardings_jax(run_mesh)[0])
350-
>>> res = export.call(exp)(run_arg)
350+
>>> res = exp.call(run_arg)
351351
>>> res.addressable_shards
352352
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
353353
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
@@ -372,19 +372,53 @@ def out_shardings_jax(
372372
for s in self.out_shardings_hlo)
373373

374374
def has_vjp(self) -> bool:
375+
"""Returns if this Exported supports VJP."""
375376
return self._get_vjp is not None
376377

377378
def vjp(self) -> Exported:
378379
"""Gets the exported VJP.
379380
380381
Returns None if not available, which can happen if the Exported has been
381-
loaded from an external format, without a VJP."""
382+
loaded from an external format without a VJP.
383+
"""
382384
if self._get_vjp is None:
383385
raise ValueError("No VJP is available")
384386
return self._get_vjp(self)
385387

388+
def serialize(self,
389+
vjp_order: int = 0) -> bytearray:
390+
"""Serializes an Exported.
391+
392+
Args:
393+
vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
394+
serialize the primal functions and two orders of the `vjp` function. This
395+
should allow 2nd order reverse mode differentiation of the deserialized
396+
function. i.e., `jax.grad(jax.grad(f)).`
397+
"""
398+
# Lazy load the serialization module, since flatbuffers is an optional
399+
# dependency.
400+
from jax._src.export.serialization import serialize
401+
return serialize(self, vjp_order=vjp_order)
402+
403+
def call(self, *args, **kwargs):
404+
return call_exported(self)(*args, **kwargs)
405+
406+
407+
def deserialize(blob: bytearray) -> Exported:
408+
"""Deserializes an Exported.
409+
410+
Args:
411+
blob: a bytearray obtained from `Exported.serialize`.
412+
"""
413+
# Lazy load the serialization module, since flatbuffers is an optional
414+
# dependency.
415+
from jax._src.export.serialization import deserialize
416+
return deserialize(blob)
417+
386418

387419
def default_lowering_platform() -> str:
420+
"""Retrieves the default lowering platform for the exporting machine.
421+
"""
388422
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
389423
return xb.canonicalize_platform(jax.default_backend())
390424

@@ -411,14 +445,20 @@ def args_specs(
411445
return shape_poly.symbolic_args_specs(args, polymorphic_shapes)
412446

413447

414-
def export(fun_jax: Callable,
415-
*,
416-
lowering_platforms: Sequence[str] | None = None,
417-
disabled_checks: Sequence[DisabledSafetyCheck] = (),
418-
_device_assignment_for_internal_jax2tf_use_only = None,
419-
) -> Callable[..., Exported]:
448+
# TODO(necula): remove this once we remove jax.experimental.export.
449+
def export_back_compat(
450+
fun_jax: Callable,
451+
*,
452+
lowering_platforms: Sequence[str] | None = None,
453+
disabled_checks: Sequence[DisabledSafetyCheck] = (),
454+
_device_assignment_for_internal_jax2tf_use_only = None,
455+
) -> Callable[..., Exported]:
420456
"""Exports native serialization for a JAX function.
421457
458+
Note: this function exists only for internal usage by jax2tf and for
459+
backwards compatibility with jax.experimental.export. Use
460+
`jax.export` instead.
461+
422462
Args:
423463
fun_jax: the function to lower and serialize.
424464
lowering_platforms:
@@ -498,6 +538,85 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
498538
_device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
499539
return do_export
500540

541+
def export(
542+
fun_jit: stages.Wrapped,
543+
*,
544+
lowering_platforms: Sequence[str] | None = None,
545+
disabled_checks: Sequence[DisabledSafetyCheck] = (),
546+
) -> Callable[..., Exported]:
547+
"""Exports a JAX function for persistent serialization.
548+
549+
Args:
550+
fun_jit: the function to export. Should be the result of `jit`.
551+
lowering_platforms:
552+
Optional sequence containing a subset of 'tpu', 'cpu',
553+
'cuda', 'rocm'. If more than one platform is specified, then
554+
the lowered code takes an argument specifying the platform.
555+
If None, then use the default JAX backend.
556+
The calling convention for multiple platforms is explained in the
557+
`jax_export.Exported` docstring.
558+
disabled_checks: the safety checks to disable. See docstring
559+
of `DisabledSafetyCheck`.
560+
561+
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
562+
or values with `.shape` and `.dtype` attributes, and returns an
563+
`Exported`.
564+
565+
Usage:
566+
>>> from jax import export
567+
>>> exported: export.Exported = export.export(jnp.sin)(
568+
... np.arange(4, dtype=np.float32))
569+
570+
# You can inspect the Exported object
571+
>>> exported.in_avals
572+
(ShapedArray(float32[4]),)
573+
>>> blob: bytearray = exported.serialize()
574+
575+
# The serialized bytes are safe to use in a separate process
576+
>>> rehydrated: export.Exported = export.deserialize(blob)
577+
>>> rehydrated.fun_name
578+
'sin'
579+
>>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
580+
Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)
581+
"""
582+
if not isinstance(fun_jit, stages.Wrapped):
583+
raise ValueError(
584+
f"Function to be exported must be the result of `jit` but is: {fun_jit}")
585+
if lowering_platforms is not None:
586+
actual_lowering_platforms = tuple(lowering_platforms)
587+
else:
588+
actual_lowering_platforms = (default_lowering_platform(),)
589+
590+
def do_export(*args_specs, **kwargs_specs) -> Exported:
591+
# TODO: move to `lower`
592+
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore]
593+
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
594+
# Static args may have no `shape` attribute.
595+
if not hasattr(aval, "shape"):
596+
continue
597+
for d in aval.shape:
598+
if shape_poly.is_symbolic_dim(d):
599+
if symbolic_scope is None:
600+
symbolic_scope = (d.scope, k_path)
601+
continue
602+
symbolic_scope[0]._check_same_scope(
603+
d, when=f"when exporting {util.fun_name(fun_jit)}",
604+
self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
605+
other_descr=shape_poly.args_kwargs_path_to_str(k_path))
606+
607+
traced = fun_jit.trace( # type: ignore
608+
*args_specs, **kwargs_specs,
609+
_experimental_lowering_parameters=mlir.LoweringParameters(
610+
platforms=actual_lowering_platforms,
611+
for_export=True,
612+
))
613+
jaxpr, fun_name = traced.jaxpr, traced.fun_name
614+
lowered = traced.lower()
615+
return _export_lowered(
616+
lowered, jaxpr, fun_name,
617+
disabled_checks=disabled_checks)
618+
return do_export
619+
501620
def _export_lowered(
502621
lowered: stages.Lowered,
503622
jaxpr: core.ClosedJaxpr, fun_name: str,
@@ -599,7 +718,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
599718
device_assignment=device_assignment,
600719
apply_jit=True,
601720
flat_primal_fun=True)
602-
return export(fun_vjp_jax,
721+
return export(fun_vjp_jax, # type: ignore[arg-type]
603722
lowering_platforms=exp_primal.lowering_platforms,
604723
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
605724

@@ -816,7 +935,7 @@ def is_token(typ, attrs):
816935

817936
def _check_lowering(lowering) -> None:
818937
if not isinstance(lowering, pxla.MeshComputation):
819-
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
938+
raise NotImplementedError(f"serialization is supported only for jit. {lowering}")
820939

821940
if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
822941
raise NotImplementedError("serialization of host_callbacks is not yet implemented")

jax/_src/export/serialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
_SERIALIZATION_VERSION = 2
4949

5050
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
51-
"""Serialize an Exported.
51+
"""Serializes an Exported.
5252
5353
Args:
5454
exp: the Exported to serialize.
@@ -64,7 +64,7 @@ def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
6464

6565

6666
def deserialize(ser: bytearray) -> _export.Exported:
67-
"""Deserialize an Exported."""
67+
"""Deserializes an Exported."""
6868
exp = ser_flatbuf.Exported.GetRootAsExported(ser)
6969
return _deserialize_exported(exp)
7070

jax/_src/internal_test_util/export_back_compat_test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def func(...): ...
8686

8787
import jax
8888
from jax import tree_util
89-
from jax.experimental import export
89+
from jax import export
9090

9191
from jax.experimental import pjit
9292

@@ -345,4 +345,4 @@ def _get_vjp(_):
345345
_get_vjp=_get_vjp)
346346

347347
# We use pjit in case there are shardings in the exported module.
348-
return pjit.pjit(export.call(exported))(*data.inputs)
348+
return pjit.pjit(exported.call)(*data.inputs)

jax/_src/stages.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from collections.abc import Sequence
3434
from dataclasses import dataclass
35-
from typing import Any, NamedTuple, Protocol, Union
35+
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
3636
import warnings
3737

3838
import jax
@@ -756,8 +756,9 @@ def cost_analysis(self) -> Any | None:
756756
return None
757757

758758

759+
@runtime_checkable
759760
class Wrapped(Protocol):
760-
"""A function ready to be specialized, lowered, and compiled.
761+
"""A function ready to be traced, lowered, and compiled.
761762
762763
This protocol reflects the output of functions such as
763764
``jax.jit``. Calling it results in JIT (just-in-time) lowering,

jax/experimental/export/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
minimum_supported_serialization_version,
1818
maximum_supported_serialization_version,
1919
Exported,
20-
export,
2120
call_exported, # TODO: deprecate
2221
call,
2322
DisabledSafetyCheck,
24-
default_lowering_platform,
23+
default_lowering_platform, # TODO: deprecate
2524
)
25+
from jax._src.export._export import export_back_compat as export
26+
2627
from jax._src.export.shape_poly import (
2728
is_symbolic_dim,
2829
symbolic_shape,
@@ -33,4 +34,6 @@
3334
serialize,
3435
deserialize,
3536
)
37+
# Import only to set the shape poly decision procedure
3638
from jax._src.export import shape_poly_decision
39+
del shape_poly_decision

jax/experimental/jax2tf/jax2tf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from jax import numpy as jnp
3737
from jax import tree_util
3838
from jax import sharding
39-
from jax.experimental import export
39+
from jax import export
4040
from jax.experimental.jax2tf import impl_no_xla
4141
from jax.interpreters import xla
4242

@@ -515,7 +515,7 @@ def _restore_context():
515515

516516
self._restore_context = _restore_context
517517
_exported_device_assignment = [None]
518-
self.exported = export.export(
518+
self.exported = _export.export_back_compat(
519519
self.fun_jax,
520520
lowering_platforms=self.native_serialization_platforms,
521521
disabled_checks=self.native_serialization_disabled_checks,

jax/experimental/jax2tf/tests/call_tf_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
import jax
2626
from jax import dlpack
2727
from jax import dtypes
28+
from jax import export
2829
from jax import lax
2930
from jax import numpy as jnp
3031
from jax._src import config
3132
from jax._src import test_util as jtu
3233
from jax._src.lib.mlir import ir
3334
from jax._src.lib.mlir.dialects import hlo
34-
from jax.experimental import export
3535
from jax.experimental import jax2tf
3636
from jax.experimental.jax2tf.tests import tf_test_util
3737
import numpy as np
@@ -778,7 +778,7 @@ def f_jax(x):
778778

779779
lowering_platforms = ("tpu", "cpu", "cuda")
780780

781-
exp = export.export(f_jax,
781+
exp = export.export(jax.jit(f_jax),
782782
lowering_platforms=lowering_platforms)(x)
783783
for jax_platform in jax_and_tf_platforms:
784784
with self.subTest(jax_platform):
@@ -787,7 +787,7 @@ def f_jax(x):
787787
logging.info("Running harness natively on %s", jax_device)
788788
native_res = f_jax(x_device)
789789
logging.info("Running exported harness on %s", jax_device)
790-
exported_res = export.call(exp)(x_device)
790+
exported_res = exp.call(x_device)
791791
self.assertAllClose(native_res, exported_res)
792792

793793
def test_multi_platform_call_tf_graph(self):

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import jax
2828
from jax import ad_checkpoint
2929
from jax import dtypes
30+
from jax import export
3031
from jax import lax
3132
from jax import numpy as jnp
3233
from jax import sharding
@@ -37,7 +38,6 @@
3738
from jax._src import test_util as jtu
3839
from jax._src import xla_bridge as xb
3940
from jax.experimental import jax2tf
40-
from jax.experimental import export
4141
from jax.experimental.jax2tf.tests import tf_test_util
4242
from jax.experimental.shard_map import shard_map
4343
from jax.experimental import pjit
@@ -1559,7 +1559,7 @@ def apply_transform(func, transform: str):
15591559
# Run the JAX native version, to check it works, and to fill caches.
15601560
_ = func_to_convert(*args)
15611561
exported = export.export(
1562-
func_to_convert,
1562+
(jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert),
15631563
lowering_platforms=("tpu",)
15641564
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))
15651565

jax/experimental/jax2tf/tests/shape_poly_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232

3333
import jax
3434
from jax.experimental import jax2tf
35-
from jax.experimental import export
3635
from jax.experimental import pjit
36+
from jax import export
3737
from jax import lax
3838
import jax.numpy as jnp
3939
from jax import random

0 commit comments

Comments
 (0)