Skip to content

Commit 14d87d3

Browse files
gneculajax authors
authored andcommitted
[export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out of jax.experimental. For now, the way to use this implementation is still through `jax.experimental.export`. Had to add a few "#type ignore" to the _export.py because previously the file was exempt from internal pytype. Will try to fix these in a later PR. PiperOrigin-RevId: 641688200
1 parent aaa559a commit 14d87d3

File tree

10 files changed

+52
-39
lines changed

10 files changed

+52
-39
lines changed

jax/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ py_library_providing_imports_info(
321321
":xla",
322322
":xla_bridge",
323323
"//jax/_src/lib",
324-
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + jax_extra_deps,
324+
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
325325
)
326326

327327
pytype_strict_library(

jax/_src/export/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

jax/experimental/export/_export.py renamed to jax/_src/export/_export.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
460460
actual_lowering_platforms = (default_lowering_platform(),)
461461

462462
# TODO: move to `lower`
463-
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
463+
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore]
464464
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
465465
# Static args may have no `shape` attribute.
466466
if not hasattr(aval, "shape"):
@@ -476,7 +476,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
476476
other_descr=shape_poly.args_kwargs_path_to_str(k_path))
477477

478478
if has_trace:
479-
traced = wrapped_fun_jax.trace(
479+
traced = wrapped_fun_jax.trace( # type: ignore
480480
*args_specs, **kwargs_specs,
481481
_experimental_lowering_parameters=mlir.LoweringParameters(
482482
platforms=actual_lowering_platforms,
@@ -547,7 +547,7 @@ def _export_lowered(
547547
elif "shards" in lowering.compile_args: # for PmapComputation
548548
out_avals_flat = lowering.compile_args["shards"].out_sharded_avals
549549
else:
550-
out_avals_flat = lowered.compile_args["out_avals"]
550+
out_avals_flat = lowered.compile_args["out_avals"] # type: ignore
551551

552552
# Log and then check the module.
553553
if logging.vlog_is_on(3):
@@ -612,7 +612,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
612612
in_shardings_hlo=in_shardings,
613613
out_shardings_hlo=out_shardings,
614614
nr_devices=nr_devices,
615-
lowering_platforms=lowering._platforms,
615+
lowering_platforms=lowering._platforms, # type: ignore
616616
ordered_effects=ordered_effects,
617617
unordered_effects=unordered_effects,
618618
disabled_safety_checks=tuple(disabled_checks),
@@ -641,7 +641,7 @@ def _module_to_bytecode(module: ir.Module) -> bytes:
641641
# and still have the payloads produced by `serialize_portable_artifact`
642642
# compatible with potential consumers from the past.
643643
target_version = hlo.get_minimum_version()
644-
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
644+
module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore
645645
mlir_str, target_version)
646646
return module_serialized
647647

@@ -688,8 +688,8 @@ def _wrap_main_func(
688688
def is_token(typ, attrs):
689689
return (typ == mlir.token_type()[0])
690690

691-
orig_input_types = orig_main.type.inputs
692-
arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
691+
orig_input_types = orig_main.type.inputs # type: ignore
692+
arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # type: ignore
693693
# The order of args: platform_index_arg, dim args, token args, array args.
694694
nr_platform_index_args = 1 if has_platform_index_argument else 0
695695
nr_dim_args = len(dim_vars)
@@ -711,8 +711,8 @@ def is_token(typ, attrs):
711711
orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args])
712712

713713
# The order of results: tokens, array results
714-
orig_output_types = orig_main.type.results
715-
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
714+
orig_output_types = orig_main.type.results # type: ignore
715+
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) # type: ignore
716716
token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types,
717717
result_attrs))
718718
if is_token(typ, attrs)]
@@ -1138,6 +1138,8 @@ def _call_exported_abstract_eval(
11381138
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
11391139
# Check that the expected shapes match the actual ones
11401140
for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
1141+
exp_aval: core.ShapedArray = exp_aval # type: ignore
1142+
actual_aval: core.ShapedArray = actual_aval # type: ignore
11411143
def pp_arg_dim(dim_idx: int | None) -> str:
11421144
return shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
11431145
arg_idx, dim_idx)
@@ -1181,10 +1183,10 @@ def pp_arg_dim(dim_idx: int | None) -> str:
11811183
exported_dim_values = [synthetic_eval.evaluate(solution[var])
11821184
for var in exported_dim_vars]
11831185
out_avals = tuple(
1184-
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
1186+
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, # type: ignore
11851187
*exported_dim_values),
1186-
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
1187-
named_shape=out_aval.named_shape)
1188+
dtype=out_aval.dtype, weak_type=out_aval.weak_type, # type: ignore
1189+
named_shape=out_aval.named_shape) # type: ignore
11881190
for out_aval in exported.out_avals)
11891191
return out_avals, set(exported.ordered_effects + exported.unordered_effects)
11901192

jax/experimental/export/serialization.fbs renamed to jax/_src/export/serialization.fbs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
// 3. Add back the licence comment at the start
2121
//
2222

23-
namespace jax.experimental.export.serialization;
23+
namespace jax.export.serialization;
2424

2525
enum PyTreeDefKind: byte {
2626
leaf = 0,

jax/experimental/export/_serialization.py renamed to jax/_src/export/serialization.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Serialization and deserialization of export.Exported
15+
# Serialization and deserialization of _export.Exported
1616

1717
from __future__ import annotations
1818

@@ -31,10 +31,10 @@
3131
from jax._src import dtypes
3232
from jax._src import effects
3333
from jax._src import tree_util
34+
from jax._src.export import serialization_generated as ser_flatbuf
35+
from jax._src.export import _export
36+
from jax._src.export import shape_poly
3437
from jax._src.lib import xla_client
35-
from jax.experimental.export import serialization_generated as ser_flatbuf
36-
from jax.experimental.export import _export
37-
from jax.experimental import export
3838

3939
import numpy as np
4040

@@ -47,7 +47,7 @@
4747
# Version 2, Dec 16th, 2023, adds the f0 dtype.
4848
_SERIALIZATION_VERSION = 2
4949

50-
def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray:
50+
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
5151
"""Serialize an Exported.
5252
5353
Args:
@@ -63,14 +63,14 @@ def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray:
6363
return builder.Output()
6464

6565

66-
def deserialize(ser: bytearray) -> export.Exported:
66+
def deserialize(ser: bytearray) -> _export.Exported:
6767
"""Deserialize an Exported."""
6868
exp = ser_flatbuf.Exported.GetRootAsExported(ser)
6969
return _deserialize_exported(exp)
7070

7171

7272
def _serialize_exported(
73-
builder: flatbuffers.Builder, exp: export.Exported, vjp_order: int
73+
builder: flatbuffers.Builder, exp: _export.Exported, vjp_order: int
7474
) -> int:
7575
# Serialize bottom-up
7676
fun_name = builder.CreateString(exp.fun_name)
@@ -150,7 +150,7 @@ def _serialize_array(
150150
return builder.EndVector()
151151

152152

153-
def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported:
153+
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
154154
serialization_version = exp.SerializationVersion()
155155
if serialization_version != _SERIALIZATION_VERSION:
156156
raise NotImplementedError(
@@ -161,7 +161,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported:
161161
_, in_tree = tree_util.tree_flatten(
162162
_deserialize_pytreedef_to_pytree(exp.InTree())
163163
)
164-
scope = export.SymbolicScope(()) # TODO: serialize the constraints
164+
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
165165
deser_aval = partial(_deserialize_aval, scope=scope)
166166
in_avals = _deserialize_tuple(
167167
exp.InAvalsLength, exp.InAvals, deser_aval
@@ -205,7 +205,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported:
205205
if vjp := exp.Vjp():
206206
_get_vjp = lambda _: _deserialize_exported(vjp)
207207

208-
return export.Exported(
208+
return _export.Exported(
209209
fun_name=fun_name,
210210
in_tree=in_tree,
211211
in_avals=in_avals,
@@ -356,7 +356,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue,
356356
aval_kind = aval.Kind()
357357
if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray:
358358
dtype = _dtype_kind_to_dtype[aval.Dtype()]
359-
shape = export.symbolic_shape(
359+
shape = shape_poly.symbolic_shape(
360360
",".join(
361361
aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength())
362362
),
@@ -445,16 +445,16 @@ def _deserialize_effect(eff: ser_flatbuf.Effect) -> core.Effect:
445445

446446

447447
def _serialize_disabled_safety_check(
448-
builder: flatbuffers.Builder, check: export.DisabledSafetyCheck
448+
builder: flatbuffers.Builder, check: _export.DisabledSafetyCheck
449449
) -> int:
450450
custom_call_target_str = check.is_custom_call()
451451
custom_call_target = None
452452
if custom_call_target_str is not None:
453453
kind = ser_flatbuf.DisabledSafetyCheckKind.custom_call
454454
custom_call_target = builder.CreateString(custom_call_target_str)
455-
elif check == export.DisabledSafetyCheck.platform():
455+
elif check == _export.DisabledSafetyCheck.platform():
456456
kind = ser_flatbuf.DisabledSafetyCheckKind.platform
457-
elif check == export.DisabledSafetyCheck.shape_assertions():
457+
elif check == _export.DisabledSafetyCheck.shape_assertions():
458458
kind = ser_flatbuf.DisabledSafetyCheckKind.shape_assertions
459459
else:
460460
raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}")
@@ -470,14 +470,14 @@ def _serialize_disabled_safety_check(
470470

471471
def _deserialize_disabled_safety_check(
472472
sc: ser_flatbuf.DisabledSafetyCheck,
473-
) -> export.DisabledSafetyCheck:
473+
) -> _export.DisabledSafetyCheck:
474474
kind = sc.Kind()
475475
if kind == ser_flatbuf.DisabledSafetyCheckKind.custom_call:
476-
return export.DisabledSafetyCheck.custom_call(
476+
return _export.DisabledSafetyCheck.custom_call(
477477
sc.CustomCallTarget().decode("utf-8")
478478
)
479479
if kind == ser_flatbuf.DisabledSafetyCheckKind.platform:
480-
return export.DisabledSafetyCheck.platform()
480+
return _export.DisabledSafetyCheck.platform()
481481
if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions:
482-
return export.DisabledSafetyCheck.shape_assertions()
482+
return _export.DisabledSafetyCheck.shape_assertions()
483483
assert False, kind

jax/experimental/export/serialization_generated.py renamed to jax/_src/export/serialization_generated.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pytype: skip-file
1516
# automatically generated by the FlatBuffers compiler, do not modify
1617

1718
# namespace: serialization

jax/experimental/export/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ py_library(
3131
name = "export",
3232
srcs = [
3333
"__init__.py",
34-
"_export.py",
35-
"_serialization.py",
36-
"serialization_generated.py",
3734
],
3835
srcs_version = "PY3",
3936
# TODO: b/255503696: enable pytype

jax/experimental/export/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
from jax.experimental.export._export import (
16+
from jax._src.export._export import (
1717
minimum_supported_serialization_version,
1818
maximum_supported_serialization_version,
1919
Exported,
@@ -29,7 +29,7 @@
2929
symbolic_args_specs,
3030
SymbolicScope,
3131
)
32-
from jax.experimental.export._serialization import (
32+
from jax._src.export.serialization import (
3333
serialize,
3434
deserialize,
3535
)

jax/experimental/jax2tf/jax2tf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from jax import tree_util
3838
from jax import sharding
3939
from jax.experimental import export
40-
from jax.experimental.export import _export
4140
from jax.experimental.jax2tf import impl_no_xla
4241
from jax.interpreters import xla
4342

@@ -60,6 +59,7 @@
6059
from jax._src import source_info_util
6160
from jax._src import util
6261
from jax._src import shard_alike
62+
from jax._src.export import _export
6363
from jax._src.export import shape_poly
6464
from jax._src.interpreters import ad
6565
from jax._src.interpreters import mlir

tests/export_back_compat_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import jax
2929
from jax import lax
30-
from jax.experimental.export import _export
30+
from jax._src.export import _export
3131

3232
from jax._src.internal_test_util import export_back_compat_test_util as bctu
3333

0 commit comments

Comments
 (0)