Skip to content

Commit 5a9b26d

Browse files
Rework Model.export and keras.export.ExportArchive to support exporting in TFLite and ONNX formats in the future (#20631)
* Rework `Model.export` and `keras.export.ExportArchive` * Try fixing PyDatasetAdapterTest CI issues
1 parent 32a642d commit 5a9b26d

File tree

14 files changed

+662
-301
lines changed

14 files changed

+662
-301
lines changed

keras/src/backend/jax/export.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import copy
2+
import inspect
3+
import itertools
4+
import string
5+
import warnings
6+
7+
from keras.src import layers
8+
from keras.src import tree
9+
from keras.src.backend.common.stateless_scope import StatelessScope
10+
from keras.src.utils.module_utils import tensorflow as tf
11+
12+
13+
class JaxExportArchive:
    """Backend-specific helper for `keras.export.ExportArchive` on JAX.

    Tracks the variables of JAX-based Keras layers/models and converts
    endpoint functions to TF functions via `jax2tf` so the archive can be
    written as a TF SavedModel. Relies on attributes provided by the
    composing archive class: `self._tf_trackable` and
    `self._convert_to_tf_variable` — TODO confirm against the base
    `ExportArchive` implementation.
    """

    def __init__(self):
        # JAX-side variables, kept in the same order as the TF-side copies
        # stored on `self._tf_trackable`, so the two lists can be zipped
        # when binding state in `add_endpoint`.
        self._backend_variables = []
        self._backend_trainable_variables = []
        self._backend_non_trainable_variables = []

    def track(self, resource):
        """Track the variables of a JAX-based Keras `Layer` or `Model`.

        Args:
            resource: the `Layer`/`Model` whose variables to track.

        Raises:
            ValueError: if `resource` is not a Keras `Layer`.
        """
        if not isinstance(resource, layers.Layer):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "JAX-based Keras `Layer` or `Model`. "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        # `resource` is guaranteed to be a `Layer` past the guard above
        # (the previous code re-checked `isinstance` here redundantly).
        # Variables in the lists below are actually part of the trackables
        # that get saved, because the lists are created in __init__.
        trainable_variables = resource.trainable_variables
        non_trainable_variables = resource.non_trainable_variables

        self._tf_trackable.trainable_variables += tree.map_structure(
            self._convert_to_tf_variable, trainable_variables
        )
        self._tf_trackable.non_trainable_variables += tree.map_structure(
            self._convert_to_tf_variable, non_trainable_variables
        )
        self._tf_trackable.variables = (
            self._tf_trackable.trainable_variables
            + self._tf_trackable.non_trainable_variables
        )

        self._backend_trainable_variables += trainable_variables
        self._backend_non_trainable_variables += non_trainable_variables
        self._backend_variables = (
            self._backend_trainable_variables
            + self._backend_non_trainable_variables
        )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        """Convert `fn` into a `tf.function` endpoint via `jax2tf`.

        Args:
            name: endpoint name (unused here; kept for API parity with the
                other backends).
            fn: the JAX function to export.
            input_signature: structure of `tf.TensorSpec`s for the inputs.
            **kwargs: may contain `jax2tf_kwargs` (dict forwarded to
                `jax2tf.convert`) and `is_static` (bool; when True, `fn` is
                assumed not to read or update any tracked variables).

        Returns:
            The resulting `tf.function`.
        """
        jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None)
        # Use `copy.copy()` to avoid modification issues.
        jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {}
        is_static = bool(kwargs.pop("is_static", False))

        # Configure `jax2tf_kwargs` defaults without overriding caller
        # choices.
        if "native_serialization" not in jax2tf_kwargs:
            jax2tf_kwargs["native_serialization"] = (
                self._check_device_compatible()
            )
        if "polymorphic_shapes" not in jax2tf_kwargs:
            jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
                input_signature
            )

        # Note: we truncate the number of parameters to what is specified by
        # `input_signature`.
        fn_signature = inspect.signature(fn)
        fn_parameters = list(fn_signature.parameters.values())

        if is_static:
            from jax.experimental import jax2tf

            jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)
            jax_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                jax_fn,
                input_signature=input_signature,
                autograph=False,
            )
        else:
            # 1. Create a stateless wrapper for `fn`
            # 2. jax2tf the stateless wrapper
            # 3. Create a stateful function that binds the variables with
            #    the jax2tf converted stateless wrapper
            # 4. Make the signature of the stateful function the same as the
            #    original function
            # 5. Wrap in a `tf.function`
            def stateless_fn(variables, *args, **kwargs):
                state_mapping = zip(self._backend_variables, variables)
                with StatelessScope(state_mapping=state_mapping) as scope:
                    output = fn(*args, **kwargs)

                # Gather updated non-trainable variables
                non_trainable_variables = []
                for var in self._backend_non_trainable_variables:
                    new_value = scope.get_current_value(var)
                    non_trainable_variables.append(new_value)
                return output, non_trainable_variables

            jax2tf_stateless_fn = self._convert_jax2tf_function(
                stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs
            )

            def stateful_fn(*args, **kwargs):
                output, non_trainable_variables = jax2tf_stateless_fn(
                    # Change the trackable `ListWrapper` to a plain `list`
                    list(self._tf_trackable.variables),
                    *args,
                    **kwargs,
                )
                # Write the updated non-trainable state back into the TF
                # variables so subsequent calls observe fresh values.
                for var, new_value in zip(
                    self._tf_trackable.non_trainable_variables,
                    non_trainable_variables,
                ):
                    var.assign(new_value)
                return output

            stateful_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                stateful_fn,
                input_signature=input_signature,
                autograph=False,
            )
        return decorated_fn

    def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None):
        """Run `jax2tf.convert` on `fn`, prepending the tracked variables'
        polymorphic shapes to the endpoint inputs' shapes (because
        `stateless_fn` takes `variables` as its first argument)."""
        from jax.experimental import jax2tf

        variables_shapes = self._to_polymorphic_shape(
            self._backend_variables, allow_none=False
        )
        input_shapes = list(jax2tf_kwargs["polymorphic_shapes"])
        jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes
        return jax2tf.convert(fn, **jax2tf_kwargs)

    def _to_polymorphic_shape(self, struct, allow_none=True):
        """Map a structure of shaped objects to jax2tf polymorphic-shape
        strings, e.g. `"(batch, a, 3)"`.

        Args:
            struct: nested structure of objects exposing a `.shape`.
            allow_none: when True, `None` dims become symbolic names
                (`batch` for the leading dim); when False, a `None` dim
                raises `ValueError`.
        """
        if allow_none:
            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz
            # for unknown non-batch dims. Defined here to be scope per endpoint.
            dim_names = itertools.chain(
                string.ascii_lowercase,
                itertools.starmap(
                    lambda a, b: a + b,
                    itertools.product(string.ascii_lowercase, repeat=2),
                ),
            )

        def convert_shape(x):
            poly_shape = []
            for index, dim in enumerate(list(x.shape)):
                if dim is not None:
                    poly_shape.append(str(dim))
                elif not allow_none:
                    raise ValueError(
                        f"Illegal None dimension in {x} with shape {x.shape}"
                    )
                elif index == 0:
                    # Leading unknown dim is treated as the batch dim.
                    poly_shape.append("batch")
                else:
                    poly_shape.append(next(dim_names))
            return "(" + ", ".join(poly_shape) + ")"

        return tree.map_structure(convert_shape, struct)

    def _check_device_compatible(self):
        """Return whether native serialization is safe: False (with a
        warning) when JAX runs on GPU but the installed TF cannot see a
        GPU, since the exported artifact would not reload in-process."""
        from jax import default_backend as jax_device

        if (
            jax_device() == "gpu"
            and len(tf.config.list_physical_devices("GPU")) == 0
        ):
            warnings.warn(
                "JAX backend is using GPU for export, but installed "
                "TF package cannot access GPU, so reloading the model with "
                "the TF runtime in the same environment will not work. "
                "To use JAX-native serialization for high-performance export "
                "and serving, please install `tensorflow-gpu` and ensure "
                "CUDA version compatibility between your JAX and TF "
                "installations."
            )
            return False
        else:
            return True

keras/src/backend/numpy/export.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class NumpyExportArchive:
    """Stub `ExportArchive` backend for numpy: exporting is unsupported."""

    @staticmethod
    def _unsupported(method_name):
        # Single place that builds the backend-unsupported error so both
        # methods stay consistent.
        raise NotImplementedError(
            f"`{method_name}` is not implemented in the numpy backend."
        )

    def track(self, resource):
        """Unsupported on the numpy backend; always raises."""
        self._unsupported("track")

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        """Unsupported on the numpy backend; always raises."""
        self._unsupported("add_endpoint")
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import tensorflow as tf
2+
3+
from keras.src import layers
4+
5+
6+
class TFExportArchive:
    """Backend-specific helper for `keras.export.ExportArchive` on
    TensorFlow.

    Relies on `self._tf_trackable` being provided by the composing
    archive class — TODO confirm against the base `ExportArchive`
    implementation.
    """

    def track(self, resource):
        """Track a TF `Trackable`; additionally record its variables when
        it is a Keras `Layer`/`Model`.

        Raises:
            ValueError: if `resource` is not a TF `Trackable`.
        """
        if not isinstance(resource, tf.__internal__.tracking.Trackable):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if not isinstance(resource, layers.Layer):
            return

        # Variables in the lists below are actually part of the trackables
        # that get saved, because the lists are created in __init__.
        trackable = self._tf_trackable
        trackable.variables += resource.variables
        trackable.trainable_variables += resource.trainable_variables
        trackable.non_trainable_variables += resource.non_trainable_variables

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        """Wrap `fn` in a non-autograph `tf.function` with the given
        input signature and return it."""
        return tf.function(
            fn, input_signature=input_signature, autograph=False
        )

keras/src/backend/torch/export.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from keras.src import layers
2+
from keras.src import tree
3+
4+
5+
class TorchExportArchive:
    """Backend-specific helper for `keras.export.ExportArchive` on torch.

    Tracks the variables of torch-based Keras layers/models by copying
    them into TF variables on `self._tf_trackable` (provided, along with
    `self._convert_to_tf_variable`, by the composing archive class —
    TODO confirm against the base `ExportArchive` implementation).
    Endpoint export is not supported yet on this backend.
    """

    def track(self, resource):
        """Track the variables of a torch-based Keras `Layer` or `Model`.

        Args:
            resource: the `Layer`/`Model` whose variables to track.

        Raises:
            ValueError: if `resource` is not a Keras `Layer`.
        """
        if not isinstance(resource, layers.Layer):
            # Fixed: the message previously said "JAX-based", copied from
            # the JAX backend; this is the torch backend.
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "torch-based Keras `Layer` or `Model`. "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        # `resource` is guaranteed to be a `Layer` past the guard above
        # (the previous code re-checked `isinstance` here redundantly).
        # Variables in the lists below are actually part of the trackables
        # that get saved, because the lists are created in __init__.
        variables = resource.variables
        trainable_variables = resource.trainable_variables
        non_trainable_variables = resource.non_trainable_variables
        self._tf_trackable.variables += tree.map_structure(
            self._convert_to_tf_variable, variables
        )
        self._tf_trackable.trainable_variables += tree.map_structure(
            self._convert_to_tf_variable, trainable_variables
        )
        self._tf_trackable.non_trainable_variables += tree.map_structure(
            self._convert_to_tf_variable, non_trainable_variables
        )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        """Not implemented for the torch backend.

        Raises:
            NotImplementedError: always.
        """
        # TODO: torch-xla?
        raise NotImplementedError(
            "`add_endpoint` is not implemented in the torch backend."
        )

0 commit comments

Comments
 (0)