
Commit 3fe7377

jax authors committed:
Merge pull request #21763 from gnecula:export_api
PiperOrigin-RevId: 641959833
2 parents 833c7ba + b33aca6 commit 3fe7377

File tree: 15 files changed (+342 −159 lines)

benchmarks/shape_poly_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 import jax
 from jax import core
 from jax._src.numpy import lax_numpy
-from jax.experimental import export
+from jax import export

 jax.config.parse_flags_with_absl()
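
This one-line import change is the pattern repeated across the commit: the export API moves from `jax.experimental.export` to the stable `jax.export`, the new `export()` requires a jit-wrapped function, and `export.call(exp)` becomes the `exp.call` method. A minimal before/after sketch (the `jnp.sin` function and input are illustrative, not from the diff):

    import jax
    import jax.numpy as jnp
    import numpy as np

    # Old (deprecated after this commit):
    #   from jax.experimental import export
    #   exp = export.export(jnp.sin)(np.arange(4, dtype=np.float32))
    #   res = export.call(exp)(np.arange(4, dtype=np.float32))

    # New stable API:
    from jax import export

    exp = export.export(jax.jit(jnp.sin))(np.arange(4, dtype=np.float32))
    res = exp.call(np.arange(4, dtype=np.float32))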

jax/_src/export/_export.py

Lines changed: 131 additions & 12 deletions
@@ -333,10 +333,10 @@ def in_shardings_jax(
       `jax.device_put`.

     Example usage:
-      >>> from jax.experimental import export
+      >>> from jax import export
       >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
       >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
-      ...    in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
+      ...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
       ...    )(np.arange(jax.device_count()))
       >>> exp.in_shardings_hlo
       ({devices=[8]<=[8]},)
@@ -347,7 +347,7 @@ def in_shardings_jax(
       # Put the args and kwargs on the appropriate devices
       >>> run_arg = jax.device_put(np.arange(jax.device_count()),
       ...                          exp.in_shardings_jax(run_mesh)[0])
-      >>> res = export.call(exp)(run_arg)
+      >>> res = exp.call(run_arg)
       >>> res.addressable_shards
       [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
        Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
@@ -372,19 +372,53 @@ def out_shardings_jax(
                  for s in self.out_shardings_hlo)

   def has_vjp(self) -> bool:
+    """Returns if this Exported supports VJP."""
     return self._get_vjp is not None

   def vjp(self) -> Exported:
     """Gets the exported VJP.

     Returns None if not available, which can happen if the Exported has been
-    loaded from an external format, without a VJP."""
+    loaded from an external format without a VJP.
+    """
     if self._get_vjp is None:
       raise ValueError("No VJP is available")
     return self._get_vjp(self)

+  def serialize(self,
+                vjp_order: int = 0) -> bytearray:
+    """Serializes an Exported.
+
+    Args:
+      vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
+        serialize the primal functions and two orders of the `vjp` function. This
+        should allow 2nd order reverse mode differentiation of the deserialized
+        function. i.e., `jax.grad(jax.grad(f)).`
+    """
+    # Lazy load the serialization module, since flatbuffers is an optional
+    # dependency.
+    from jax._src.export.serialization import serialize
+    return serialize(self, vjp_order=vjp_order)
+
+  def call(self, *args, **kwargs):
+    return call_exported(self)(*args, **kwargs)
+
+
+def deserialize(blob: bytearray) -> Exported:
+  """Deserializes an Exported.
+
+  Args:
+    blob: a bytearray obtained from `Exported.serialize`.
+  """
+  # Lazy load the serialization module, since flatbuffers is an optional
+  # dependency.
+  from jax._src.export.serialization import deserialize
+  return deserialize(blob)
+

 def default_lowering_platform() -> str:
+  """Retrieves the default lowering platform for the exporting machine.
+  """
   # Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
   return xb.canonicalize_platform(jax.default_backend())
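
The `serialize`/`deserialize` pair and the `call` method added above make an export persistent and directly invocable. A minimal roundtrip sketch under the new API (the file path and the `jnp.cos` example are illustrative, not from the diff):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import export

    exp = export.export(jax.jit(jnp.cos))(np.float32(1.0))
    blob: bytearray = exp.serialize()

    # The bytes may cross a process or machine boundary, e.g. via a file.
    with open("/tmp/cos.jax_export", "wb") as f:   # hypothetical path
      f.write(blob)
    with open("/tmp/cos.jax_export", "rb") as f:
      rehydrated = export.deserialize(bytearray(f.read()))

    assert np.allclose(rehydrated.call(np.float32(1.0)), np.cos(1.0))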

@@ -411,14 +445,20 @@ def args_specs(
   return shape_poly.symbolic_args_specs(args, polymorphic_shapes)


-def export(fun_jax: Callable,
-           *,
-           lowering_platforms: Sequence[str] | None = None,
-           disabled_checks: Sequence[DisabledSafetyCheck] = (),
-           _device_assignment_for_internal_jax2tf_use_only = None,
-           ) -> Callable[..., Exported]:
+# TODO(necula): remove this once we remove jax.experimental.export.
+def export_back_compat(
+    fun_jax: Callable,
+    *,
+    lowering_platforms: Sequence[str] | None = None,
+    disabled_checks: Sequence[DisabledSafetyCheck] = (),
+    _device_assignment_for_internal_jax2tf_use_only = None,
+    ) -> Callable[..., Exported]:
   """Exports native serialization for a JAX function.

+  Note: this function exists only for internal usage by jax2tf and for
+  backwards compatibility with jax.experimental.export. Use
+  `jax.export` instead.
+
   Args:
     fun_jax: the function to lower and serialize.
     lowering_platforms:
@@ -498,6 +538,85 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
         _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
   return do_export

+def export(
+    fun_jit: stages.Wrapped,
+    *,
+    lowering_platforms: Sequence[str] | None = None,
+    disabled_checks: Sequence[DisabledSafetyCheck] = (),
+    ) -> Callable[..., Exported]:
+  """Exports a JAX function for persistent serialization.
+
+  Args:
+    fun_jit: the function to export. Should be the result of `jit`.
+    lowering_platforms:
+      Optional sequence containing a subset of 'tpu', 'cpu',
+      'cuda', 'rocm'. If more than one platform is specified, then
+      the lowered code takes an argument specifying the platform.
+      If None, then use the default JAX backend.
+      The calling convention for multiple platforms is explained in the
+      `jax_export.Exported` docstring.
+    disabled_checks: the safety checks to disable. See docstring
+      of `DisabledSafetyCheck`.
+
+  Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
+    or values with `.shape` and `.dtype` attributes, and returns an
+    `Exported`.
+
+  Usage:
+      >>> from jax import export
+      >>> exported: export.Exported = export.export(jnp.sin)(
+      ...     np.arange(4, dtype=np.float32))
+
+      # You can inspect the Exported object
+      >>> exported.in_avals
+      (ShapedArray(float32[4]),)
+      >>> blob: bytearray = exported.serialize()
+
+      # The serialized bytes are safe to use in a separate process
+      >>> rehydrated: export.Exported = export.deserialize(blob)
+      >>> rehydrated.fun_name
+      'sin'
+      >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
+      Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)
+  """
+  if not isinstance(fun_jit, stages.Wrapped):
+    raise ValueError(
+        f"Function to be exported must be the result of `jit` but is: {fun_jit}")
+  if lowering_platforms is not None:
+    actual_lowering_platforms = tuple(lowering_platforms)
+  else:
+    actual_lowering_platforms = (default_lowering_platform(),)
+
+  def do_export(*args_specs, **kwargs_specs) -> Exported:
+    # TODO: move to `lower`
+    symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None  # type: ignore[invalid-annotation,unused-ignore]
+    for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
+      # Static args may have no `shape` attribute.
+      if not hasattr(aval, "shape"):
+        continue
+      for d in aval.shape:
+        if shape_poly.is_symbolic_dim(d):
+          if symbolic_scope is None:
+            symbolic_scope = (d.scope, k_path)
+            continue
+          symbolic_scope[0]._check_same_scope(
+              d, when=f"when exporting {util.fun_name(fun_jit)}",
+              self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
+              other_descr=shape_poly.args_kwargs_path_to_str(k_path))
+
+    traced = fun_jit.trace(  # type: ignore
+        *args_specs, **kwargs_specs,
+        _experimental_lowering_parameters=mlir.LoweringParameters(
+            platforms=actual_lowering_platforms,
+            for_export=True,
+        ))
+    jaxpr, fun_name = traced.jaxpr, traced.fun_name
+    lowered = traced.lower()
+    return _export_lowered(
+        lowered, jaxpr, fun_name,
+        disabled_checks=disabled_checks)
+  return do_export
+
 def _export_lowered(
     lowered: stages.Lowered,
     jaxpr: core.ClosedJaxpr, fun_name: str,
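
The scope check in `do_export` above rejects mixing symbolic dimensions that come from separate `symbolic_shape` calls. A sketch of the intended usage, where both arguments share a single scope (names are illustrative, and this assumes the new `jax.export` re-exports `symbolic_shape` as the experimental module did):

    import jax
    import jax.numpy as jnp
    from jax import export

    # One symbolic_shape call creates one scope; reuse `b` for both inputs.
    b, = export.symbolic_shape("b")
    spec = jax.ShapeDtypeStruct((b, 4), jnp.float32)
    exp = export.export(jax.jit(lambda x, y: x + y))(spec, spec)
    # A second export.symbolic_shape("b") would create a distinct scope and
    # would trip _check_same_scope if mixed with the first in one export.
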
@@ -599,7 +718,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
                                            device_assignment=device_assignment,
                                            apply_jit=True,
                                            flat_primal_fun=True)
-  return export(fun_vjp_jax,
+  return export(fun_vjp_jax,  # type: ignore[arg-type]
                 lowering_platforms=exp_primal.lowering_platforms,
                 disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)

@@ -816,7 +935,7 @@ def is_token(typ, attrs):

 def _check_lowering(lowering) -> None:
   if not isinstance(lowering, pxla.MeshComputation):
-    raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
+    raise NotImplementedError(f"serialization is supported only for jit. {lowering}")

   if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
     raise NotImplementedError("serialization of host_callbacks is not yet implemented")

jax/_src/export/serialization.py

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@
 _SERIALIZATION_VERSION = 2

 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
-  """Serialize an Exported.
+  """Serializes an Exported.

   Args:
     exp: the Exported to serialize.
@@ -64,7 +64,7 @@ def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:


 def deserialize(ser: bytearray) -> _export.Exported:
-  """Deserialize an Exported."""
+  """Deserializes an Exported."""
   exp = ser_flatbuf.Exported.GetRootAsExported(ser)
   return _deserialize_exported(exp)
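
As the `vjp_order` docstring in `_export.py` explains, gradients survive serialization only if VJP functions are included. A sketch of how that should compose with the new methods, under the assumption that one serialized VJP order permits one level of reverse-mode differentiation of `call`:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import export

    exp = export.export(jax.jit(jnp.sin))(np.float32(0.5))
    rehydrated = export.deserialize(exp.serialize(vjp_order=1))
    # Expect approximately cos(0.5).
    g = jax.grad(rehydrated.call)(np.float32(0.5))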

jax/_src/internal_test_util/export_back_compat_test_util.py

Lines changed: 2 additions & 2 deletions
@@ -86,7 +86,7 @@ def func(...): ...

 import jax
 from jax import tree_util
-from jax.experimental import export
+from jax import export

 from jax.experimental import pjit

@@ -345,4 +345,4 @@ def _get_vjp(_):
                          _get_vjp=_get_vjp)

   # We use pjit in case there are shardings in the exported module.
-  return pjit.pjit(export.call(exported))(*data.inputs)
+  return pjit.pjit(exported.call)(*data.inputs)
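
The switch from `export.call(exported)` to the bound method `exported.call` works because `call` is an ordinary callable, so it can be staged again with `jit`/`pjit` exactly as the test utility does. A small sketch (function and input are illustrative):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import export

    x = np.arange(4, dtype=np.float32)
    exp = export.export(jax.jit(jnp.sin))(x)
    res = jax.jit(exp.call)(x)   # re-staging the exported computation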

jax/_src/stages.py

Lines changed: 3 additions & 2 deletions
@@ -32,7 +32,7 @@

 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, NamedTuple, Protocol, Union
+from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
 import warnings

 import jax
@@ -756,8 +756,9 @@ def cost_analysis(self) -> Any | None:
     return None


+@runtime_checkable
 class Wrapped(Protocol):
-  """A function ready to be specialized, lowered, and compiled.
+  """A function ready to be traced, lowered, and compiled.

   This protocol reflects the output of functions such as
   ``jax.jit``. Calling it results in JIT (just-in-time) lowering,
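
Marking `Wrapped` with `@runtime_checkable` is what lets the new `export()` in `_export.py` validate its argument via `isinstance`: a plain Python function fails the structural check because it lacks the protocol's methods. A quick sketch (the doubling function is illustrative):

    import jax
    from jax._src import stages   # internal module, path as in this commit

    def f(x):
      return x * 2.0

    print(isinstance(f, stages.Wrapped))           # False: plain function
    print(isinstance(jax.jit(f), stages.Wrapped))  # True: output of jit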

jax/experimental/export/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -17,12 +17,13 @@
     minimum_supported_serialization_version,
     maximum_supported_serialization_version,
     Exported,
-    export,
     call_exported,  # TODO: deprecate
     call,
     DisabledSafetyCheck,
-    default_lowering_platform,
+    default_lowering_platform,  # TODO: deprecate
 )
+from jax._src.export._export import export_back_compat as export
+
 from jax._src.export.shape_poly import (
     is_symbolic_dim,
     symbolic_shape,
@@ -33,4 +34,6 @@
     serialize,
     deserialize,
 )
+# Import only to set the shape poly decision procedure
 from jax._src.export import shape_poly_decision
+del shape_poly_decision

jax/experimental/jax2tf/jax2tf.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@
 from jax import numpy as jnp
 from jax import tree_util
 from jax import sharding
-from jax.experimental import export
+from jax import export
 from jax.experimental.jax2tf import impl_no_xla
 from jax.interpreters import xla

@@ -515,7 +515,7 @@ def _restore_context():

     self._restore_context = _restore_context
     _exported_device_assignment = [None]
-    self.exported = export.export(
+    self.exported = _export.export_back_compat(
         self.fun_jax,
         lowering_platforms=self.native_serialization_platforms,
         disabled_checks=self.native_serialization_disabled_checks,

jax/experimental/jax2tf/tests/call_tf_test.py

Lines changed: 3 additions & 3 deletions
@@ -25,13 +25,13 @@
 import jax
 from jax import dlpack
 from jax import dtypes
+from jax import export
 from jax import lax
 from jax import numpy as jnp
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax.experimental import export
 from jax.experimental import jax2tf
 from jax.experimental.jax2tf.tests import tf_test_util
 import numpy as np
@@ -778,7 +778,7 @@ def f_jax(x):

     lowering_platforms = ("tpu", "cpu", "cuda")

-    exp = export.export(f_jax,
+    exp = export.export(jax.jit(f_jax),
                         lowering_platforms=lowering_platforms)(x)
     for jax_platform in jax_and_tf_platforms:
       with self.subTest(jax_platform):
@@ -787,7 +787,7 @@ def f_jax(x):
         logging.info("Running harness natively on %s", jax_device)
         native_res = f_jax(x_device)
         logging.info("Running exported harness on %s", jax_device)
-        exported_res = export.call(exp)(x_device)
+        exported_res = exp.call(x_device)
         self.assertAllClose(native_res, exported_res)

   def test_multi_platform_call_tf_graph(self):

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 2 additions & 2 deletions
@@ -27,6 +27,7 @@
 import jax
 from jax import ad_checkpoint
 from jax import dtypes
+from jax import export
 from jax import lax
 from jax import numpy as jnp
 from jax import sharding
@@ -37,7 +38,6 @@
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax.experimental import jax2tf
-from jax.experimental import export
 from jax.experimental.jax2tf.tests import tf_test_util
 from jax.experimental.shard_map import shard_map
 from jax.experimental import pjit
@@ -1559,7 +1559,7 @@ def apply_transform(func, transform: str):
     # Run the JAX native version, to check it works, and to fill caches.
     _ = func_to_convert(*args)
     exported = export.export(
-        func_to_convert,
+        (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert),
         lowering_platforms=("tpu",)
     )(*(core.ShapedArray(a.shape, a.dtype) for a in args))

jax/experimental/jax2tf/tests/shape_poly_test.py

Lines changed: 1 addition & 1 deletion
@@ -32,8 +32,8 @@

 import jax
 from jax.experimental import jax2tf
-from jax.experimental import export
 from jax.experimental import pjit
+from jax import export
 from jax import lax
 import jax.numpy as jnp
 from jax import random
