Skip to content

Commit 67d1ddf

Browse files
Introduces support for exporting SavedModel in the torch backend using torch-xla (#20685)
* Add support for exporting savedmodel in the torch backend
* Fix `actions.yml`
* Fix CI
* Remove unused `_mangle_tf_root_scope_name` and add `import_error_msg` to `LazyModule`
* Ignore `export_lib_test` in torch GPU CI
1 parent ca58091 commit 67d1ddf

File tree

11 files changed

+334
-92
lines changed

11 files changed

+334
-92
lines changed

.github/workflows/actions.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ jobs:
5454
fi
5555
pip install -r $REQUIREMENTS_FILE --progress-bar off --upgrade
5656
pip uninstall -y keras keras-nightly
57-
pip install tf_keras==2.16.0 --progress-bar off --upgrade
5857
pip install -e "." --progress-bar off --upgrade
5958
- name: Test applications with pytest
6059
if: ${{ steps.filter.outputs.applications == 'true' }}

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ then
7272
# Raise error if GPU is not detected.
7373
python3 -c 'import torch;assert torch.cuda.is_available()'
7474

75+
# TODO: Update LD_LIBRARY_PATH for keras/src/export/export_lib_test.py (ignored below until then).
7576
pytest keras --ignore keras/src/applications \
77+
--ignore keras/src/export/export_lib_test.py \
7678
--cov=keras \
7779
--cov-config=pyproject.toml
7880

keras/src/backend/torch/export.py

Lines changed: 132 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,142 @@
1-
from keras.src import layers
1+
import copy
2+
import warnings
3+
4+
import torch
5+
6+
from keras.src import backend
7+
from keras.src import ops
28
from keras.src import tree
9+
from keras.src.utils.module_utils import tensorflow as tf
10+
from keras.src.utils.module_utils import torch_xla
311

412

513
class TorchExportArchive:
614
def track(self, resource):
7-
if not isinstance(resource, layers.Layer):
8-
raise ValueError(
9-
"Invalid resource type. Expected an instance of a "
10-
"JAX-based Keras `Layer` or `Model`. "
11-
f"Received instead an object of type '{type(resource)}'. "
12-
f"Object received: {resource}"
13-
)
15+
raise NotImplementedError(
16+
"`track` is not implemented in the torch backend. Use"
17+
"`track_and_add_endpoint` instead."
18+
)
1419

15-
if isinstance(resource, layers.Layer):
16-
# Variables in the lists below are actually part of the trackables
17-
# that get saved, because the lists are created in __init__.
18-
variables = resource.variables
19-
trainable_variables = resource.trainable_variables
20-
non_trainable_variables = resource.non_trainable_variables
21-
self._tf_trackable.variables += tree.map_structure(
22-
self._convert_to_tf_variable, variables
23-
)
24-
self._tf_trackable.trainable_variables += tree.map_structure(
25-
self._convert_to_tf_variable, trainable_variables
20+
def add_endpoint(self, name, fn, input_signature, **kwargs):
21+
raise NotImplementedError(
22+
"`add_endpoint` is not implemented in the torch backend. Use"
23+
"`track_and_add_endpoint` instead."
24+
)
25+
26+
def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
27+
# Disable false alarms related to lifting parameters.
28+
warnings.filterwarnings("ignore", message=".*created when tracing.*")
29+
warnings.filterwarnings(
30+
"ignore", message=".*Unable to find the path of the module.*"
31+
)
32+
33+
if not isinstance(resource, torch.nn.Module):
34+
raise TypeError(
35+
"`resource` must be an instance of `torch.nn.Module`. "
36+
f"Received: resource={resource} (of type {type(resource)})"
2637
)
27-
self._tf_trackable.non_trainable_variables += tree.map_structure(
28-
self._convert_to_tf_variable, non_trainable_variables
38+
39+
def _check_input_signature(input_spec):
40+
for s in tree.flatten(input_spec.shape):
41+
if s is None:
42+
raise ValueError(
43+
"The shape in the `input_spec` must be fully "
44+
f"specified. Received: input_spec={input_spec}"
45+
)
46+
47+
def _to_torch_tensor(x, replace_none_number=1):
48+
shape = backend.standardize_shape(x.shape)
49+
shape = tuple(
50+
s if s is not None else replace_none_number for s in shape
2951
)
52+
return ops.ones(shape, x.dtype)
3053

31-
def add_endpoint(self, name, fn, input_signature=None, **kwargs):
32-
# TODO: torch-xla?
33-
raise NotImplementedError(
34-
"`add_endpoint` is not implemented in the torch backend."
54+
tree.map_structure(_check_input_signature, input_signature)
55+
sample_inputs = tree.map_structure(_to_torch_tensor, input_signature)
56+
sample_inputs = tuple(sample_inputs)
57+
58+
# Ref: torch_xla.tf_saved_model_integration
59+
# TODO: Utilize `dynamic_shapes`
60+
exported = torch.export.export(
61+
resource, sample_inputs, dynamic_shapes=None, strict=False
62+
)
63+
options = torch_xla.stablehlo.StableHLOExportOptions(
64+
override_tracing_arguments=sample_inputs
65+
)
66+
stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo(
67+
exported, options
68+
)
69+
state_dict_keys = list(stablehlo_model._bundle.state_dict.keys())
70+
71+
# Remove unused variables.
72+
for k in state_dict_keys:
73+
if "lifted" not in k:
74+
stablehlo_model._bundle.state_dict.pop(k)
75+
76+
bundle = copy.deepcopy(stablehlo_model._bundle)
77+
bundle.state_dict = {
78+
k: tf.Variable(v, trainable=False, name=k)
79+
for k, v in bundle.state_dict.items()
80+
}
81+
bundle.additional_constants = [
82+
tf.Variable(v, trainable=False) for v in bundle.additional_constants
83+
]
84+
85+
# Track variables in `bundle` for `write_out`.
86+
self._tf_trackable.variables += (
87+
list(bundle.state_dict.values()) + bundle.additional_constants
88+
)
89+
90+
# Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf
91+
def make_tf_function(func, bundle):
92+
from tensorflow.compiler.tf2xla.python import xla as tfxla
93+
94+
def _get_shape_with_dynamic(signature):
95+
shape = copy.copy(signature.shape)
96+
for i in signature.dynamic_dims:
97+
shape[i] = None
98+
return shape
99+
100+
def _extract_call_parameters(args, meta, bundle):
101+
call_args = []
102+
if meta.input_pytree_spec is not None:
103+
args = tree.flatten(args)
104+
for loc in meta.input_locations:
105+
if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER:
106+
call_args.append(bundle.state_dict[loc.name])
107+
elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT:
108+
call_args.append(
109+
bundle.additional_constants[loc.position]
110+
)
111+
else:
112+
call_args.append(args[loc.position])
113+
return call_args
114+
115+
def inner(*args):
116+
Touts = [sig.dtype for sig in func.meta.output_signature]
117+
Souts = [
118+
_get_shape_with_dynamic(sig)
119+
for sig in func.meta.output_signature
120+
]
121+
call_args = _extract_call_parameters(args, func.meta, bundle)
122+
results = tfxla.call_module(
123+
tuple(call_args),
124+
version=5,
125+
Tout=Touts, # dtype information
126+
Sout=Souts, # Shape information
127+
function_list=[],
128+
module=func.bytecode,
129+
)
130+
if len(Souts) == 1:
131+
results = results[0]
132+
return results
133+
134+
return inner
135+
136+
decorated_fn = tf.function(
137+
make_tf_function(
138+
stablehlo_model._bundle.stablehlo_funcs[0], bundle
139+
),
140+
input_signature=input_signature,
35141
)
142+
return decorated_fn

keras/src/export/export_lib.py

Lines changed: 100 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class ExportArchive(BackendExportArchive):
9191
9292
**Note on resource tracking:**
9393
94-
`ExportArchive` is able to automatically track all `tf.Variables` used
94+
`ExportArchive` is able to automatically track all `keras.Variables` used
9595
by its endpoints, so most of the time calling `.track(model)`
9696
is not strictly required. However, if your model uses lookup layers such
9797
as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
@@ -104,9 +104,10 @@ class ExportArchive(BackendExportArchive):
104104

105105
def __init__(self):
106106
super().__init__()
107-
if backend.backend() not in ("tensorflow", "jax"):
107+
if backend.backend() not in ("tensorflow", "jax", "torch"):
108108
raise NotImplementedError(
109-
"The export API is only compatible with JAX and TF backends."
109+
"`ExportArchive` is only compatible with TensorFlow, JAX and "
110+
"Torch backends."
110111
)
111112

112113
self._endpoint_names = []
@@ -141,8 +142,8 @@ def track(self, resource):
141142
(`TextVectorization`, `IntegerLookup`, `StringLookup`)
142143
are automatically tracked in `add_endpoint()`.
143144
144-
Arguments:
145-
resource: A trackable TensorFlow resource.
145+
Args:
146+
resource: A trackable Keras resource, such as a layer or model.
146147
"""
147148
if isinstance(resource, layers.Layer) and not resource.built:
148149
raise ValueError(
@@ -334,12 +335,78 @@ def serving_fn(x):
334335
self._endpoint_names.append(name)
335336
return decorated_fn
336337

338+
def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
339+
"""Track the variables and register a new serving endpoint.
340+
341+
This function combines the functionality of `track` and `add_endpoint`.
342+
It tracks the variables of the `resource` (either a layer or a model)
343+
and registers a serving endpoint using `resource.__call__`.
344+
345+
Args:
346+
name: `str`. The name of the endpoint.
347+
resource: A trackable Keras resource, such as a layer or model.
348+
input_signature: Specifies the shape and dtype of the inputs
349+
passed to `resource.__call__`. Can be a structure of
350+
`keras.InputSpec`, `tf.TensorSpec`, `backend.KerasTensor`, or
351+
backend tensor. Each entry is converted to a tensor spec
352+
before the endpoint is registered. This argument is required:
353+
the method signature gives it no default value.
354+
**kwargs: Additional keyword arguments:
355+
- Specific to the JAX backend:
356+
- `is_static`: Optional `bool`. Indicates whether `fn` is
357+
static. Set to `False` if `fn` involves state updates
358+
(e.g., RNG seeds).
359+
- `jax2tf_kwargs`: Optional `dict`. Arguments for
360+
`jax2tf.convert`. See [`jax2tf.convert`](
361+
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
362+
If `native_serialization` and `polymorphic_shapes` are
363+
not provided, they are automatically computed.
364+
365+
"""
366+
if name in self._endpoint_names:
367+
raise ValueError(f"Endpoint name '{name}' is already taken.")
368+
if not isinstance(resource, layers.Layer):
369+
raise ValueError(
370+
"Invalid resource type. Expected an instance of a Keras "
371+
"`Layer` or `Model`. "
372+
f"Received: resource={resource} (of type {type(resource)})"
373+
)
374+
if not resource.built:
375+
raise ValueError(
376+
"The layer provided has not yet been built. "
377+
"It must be built before export."
378+
)
379+
if backend.backend() != "jax":
380+
if "jax2tf_kwargs" in kwargs or "is_static" in kwargs:
381+
raise ValueError(
382+
"'jax2tf_kwargs' and 'is_static' are only supported with "
383+
f"the jax backend. Current backend: {backend.backend()}"
384+
)
385+
386+
input_signature = tree.map_structure(_make_tensor_spec, input_signature)
387+
388+
if not hasattr(BackendExportArchive, "track_and_add_endpoint"):
389+
# Default behavior.
390+
self.track(resource)
391+
return self.add_endpoint(
392+
name, resource.__call__, input_signature, **kwargs
393+
)
394+
else:
395+
# Special case for the torch backend.
396+
decorated_fn = BackendExportArchive.track_and_add_endpoint(
397+
self, name, resource, input_signature, **kwargs
398+
)
399+
self._endpoint_signatures[name] = input_signature
400+
setattr(self._tf_trackable, name, decorated_fn)
401+
self._endpoint_names.append(name)
402+
return decorated_fn
403+
337404
def add_variable_collection(self, name, variables):
338405
"""Register a set of variables to be retrieved after reloading.
339406
340407
Arguments:
341408
name: The string name for the collection.
342-
variables: A tuple/list/set of `tf.Variable` instances.
409+
variables: A tuple/list/set of `keras.Variable` instances.
343410
344411
Example:
345412
@@ -496,9 +563,6 @@ def export_saved_model(
496563
):
497564
"""Export the model as a TensorFlow SavedModel artifact for inference.
498565
499-
**Note:** This feature is currently supported only with TensorFlow and
500-
JAX backends.
501-
502566
This method lets you export a model to a lightweight SavedModel artifact
503567
that contains the model's forward pass only (its `call()` method)
504568
and can be served via e.g. TensorFlow Serving. The forward pass is
@@ -527,6 +591,14 @@ def export_saved_model(
527591
If `native_serialization` and `polymorphic_shapes` are not
528592
provided, they are automatically computed.
529593
594+
**Note:** This feature is currently supported only with TensorFlow, JAX and
595+
Torch backends. Support for the Torch backend is experimental.
596+
597+
**Note:** The dynamic shape feature is not yet supported with the Torch
598+
backend. As a result, you must fully define the shapes of the inputs using
599+
`input_signature`. If `input_signature` is not provided, all instances of
600+
`None` (such as the batch size) will be replaced with `1`.
601+
530602
Example:
531603
532604
```python
@@ -543,28 +615,29 @@ def export_saved_model(
543615
`export()` method relies on `ExportArchive` internally.
544616
"""
545617
export_archive = ExportArchive()
546-
export_archive.track(model)
547-
if isinstance(model, (Functional, Sequential)):
548-
if input_signature is None:
618+
if input_signature is None:
619+
if not model.built:
620+
raise ValueError(
621+
"The layer provided has not yet been built. "
622+
"It must be built before export."
623+
)
624+
if isinstance(model, (Functional, Sequential)):
549625
input_signature = tree.map_structure(
550626
_make_tensor_spec, model.inputs
551627
)
552-
if isinstance(input_signature, list) and len(input_signature) > 1:
553-
input_signature = [input_signature]
554-
export_archive.add_endpoint(
555-
"serve", model.__call__, input_signature, **kwargs
556-
)
557-
else:
558-
if input_signature is None:
628+
if isinstance(input_signature, list) and len(input_signature) > 1:
629+
input_signature = [input_signature]
630+
else:
559631
input_signature = _get_input_signature(model)
560-
if not input_signature or not model._called:
561-
raise ValueError(
562-
"The model provided has never called. "
563-
"It must be called at least once before export."
564-
)
565-
export_archive.add_endpoint(
566-
"serve", model.__call__, input_signature, **kwargs
567-
)
632+
if not input_signature or not model._called:
633+
raise ValueError(
634+
"The model provided has never called. "
635+
"It must be called at least once before export."
636+
)
637+
638+
export_archive.track_and_add_endpoint(
639+
"serve", model, input_signature, **kwargs
640+
)
568641
export_archive.write_out(filepath, verbose=verbose)
569642

570643

0 commit comments

Comments
 (0)