Skip to content

Commit 77883ff

Browse files
[OpenVINO backend] support export model from the supported backends to openvino format (#21486)
* [OpenVINO backend] support export model using openvino format * add export for openvino backend * adding tests for openvino export format * fix dynamic shape handling Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * add support for jax backend and avoid loading models on disk for openvino format * support exporting torch backend to openvino format * avoid core dumps by jax * remove redundant code * fix typo * fix example format --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 672c731 commit 77883ff

File tree

5 files changed

+420
-1
lines changed

5 files changed

+420
-1
lines changed

keras/src/backend/openvino/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,8 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
625625
dtype = standardize_dtype(type(x))
626626
ov_type = OPENVINO_DTYPES[dtype]
627627
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
628+
elif isinstance(x, ov.Output):
629+
return OpenVINOKerasTensor(x)
628630
if isinstance(x, Variable):
629631
x = x.value
630632
if dtype and dtype != x.dtype:

keras/src/export/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from keras.src.export.onnx import export_onnx
2+
from keras.src.export.openvino import export_openvino
23
from keras.src.export.saved_model import ExportArchive
34
from keras.src.export.saved_model import export_saved_model
45
from keras.src.export.tfsm_layer import TFSMLayer

keras/src/export/openvino.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import warnings
2+
3+
from keras.src import backend
4+
from keras.src import tree
5+
from keras.src.export.export_utils import convert_spec_to_tensor
6+
from keras.src.export.export_utils import get_input_signature
7+
from keras.src.export.export_utils import make_tf_tensor_spec
8+
from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME
9+
from keras.src.export.saved_model import ExportArchive
10+
from keras.src.utils import io_utils
11+
12+
13+
def export_openvino(
    model, filepath, verbose=None, input_signature=None, **kwargs
):
    """Export the model as an OpenVINO IR artifact for inference.

    This method exports the model to the OpenVINO IR format,
    which includes two files:
    a `.xml` file containing the model structure and a `.bin` file
    containing the weights.
    The exported model contains only the forward pass
    (i.e., the model's `call()` method), and can be deployed with the
    OpenVINO Runtime for fast inference on CPU and other Intel hardware.

    Args:
        model: The model to export. It is called (or traced) on inputs
            built from `input_signature` to produce the OpenVINO graph.
        filepath: `str` or `pathlib.Path`. Path to the output `.xml` file.
            The corresponding `.bin` file will be saved alongside it.
        verbose: Optional `bool`. Whether to print a confirmation message
            after export. If `None`, the message is printed (same as
            `True`).
        input_signature: Optional. Specifies the shape and dtype of the
            model inputs. If not provided, it will be inferred.
        **kwargs: Additional keyword arguments. They are only used with
            the TensorFlow and JAX backends, where they are forwarded to
            `get_concrete_fn`.

    Raises:
        ValueError: If `filepath` does not end with `.xml`.
        NotImplementedError: If the active backend is not one of
            OpenVINO, TensorFlow, JAX or Torch.

    Example:

    ```python
    import keras

    # Define or load a Keras model
    model = keras.models.Sequential([
        keras.layers.Input(shape=(128,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(10)
    ])

    # Export to OpenVINO IR
    model.export("model.xml", format="openvino")
    ```
    """
    # Normalize path-like objects (e.g. `pathlib.Path`) to `str` so the
    # suffix check and the message below work for every documented type.
    filepath = str(filepath)
    # Raise explicitly rather than `assert`: assertions are stripped when
    # Python runs with `-O`, which would silently skip this validation.
    if not filepath.endswith(".xml"):
        raise ValueError(
            "The OpenVINO export requires the filepath to end with '.xml'. "
            f"Got: {filepath}"
        )

    import openvino as ov
    from openvino.runtime import opset14 as ov_opset

    from keras.src.backend.openvino.core import OPENVINO_DTYPES
    from keras.src.backend.openvino.core import OpenVINOKerasTensor

    actual_verbose = verbose if verbose is not None else True

    if input_signature is None:
        input_signature = get_input_signature(model)

    if backend.backend() == "openvino":
        import inspect

        def parameterize_inputs(inputs, prefix=""):
            # Recursively replace every sample tensor with an OpenVINO
            # `Parameter` node, preserving the nested structure.
            if isinstance(inputs, (list, tuple)):
                return [
                    parameterize_inputs(e, f"{prefix}{i}")
                    for i, e in enumerate(inputs)
                ]
            elif isinstance(inputs, dict):
                return {k: parameterize_inputs(v, k) for k, v in inputs.items()}
            elif isinstance(inputs, OpenVINOKerasTensor):
                ov_type = OPENVINO_DTYPES[str(inputs.dtype)]
                ov_shape = list(inputs.shape)
                param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)
                param.set_friendly_name(prefix)
                return OpenVINOKerasTensor(param.output(0))
            else:
                raise TypeError(f"Unknown input type: {type(inputs)}")

        if isinstance(input_signature, list) and len(input_signature) == 1:
            input_signature = input_signature[0]

        # `None` dimensions are temporarily replaced by 1 to build the
        # graph; the dynamic axes are restored on the model inputs below.
        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        params = parameterize_inputs(sample_inputs)
        signature = inspect.signature(model.call)
        if len(signature.parameters) > 1 and isinstance(params, (list, tuple)):
            outputs = model(*params)
        else:
            outputs = model(params)
        parameters = [p.output.get_node() for p in tree.flatten(params)]
        results = [ov_opset.result(r.output) for r in tree.flatten(outputs)]
        ov_model = ov.Model(results=results, parameters=parameters)
        flat_specs = tree.flatten(input_signature)
        for ov_input, spec in zip(ov_model.inputs, flat_specs):
            # Respect the dynamic axes from the original input signature.
            dynamic_shape_dims = [
                -1 if dim is None else dim for dim in spec.shape
            ]
            dynamic_shape = ov.PartialShape(dynamic_shape_dims)
            ov_input.get_node().set_partial_shape(dynamic_shape)

    elif backend.backend() in ("tensorflow", "jax"):
        inputs = tree.map_structure(make_tf_tensor_spec, input_signature)
        decorated_fn = get_concrete_fn(model, inputs, **kwargs)
        ov_model = ov.convert_model(decorated_fn)
    elif backend.backend() == "torch":
        import torch

        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        sample_inputs = tuple(sample_inputs)
        if hasattr(model, "eval"):
            model.eval()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            traced = torch.jit.trace(model, sample_inputs)
            ov_model = ov.convert_model(traced)
    else:
        raise NotImplementedError(
            "`export_openvino` is only compatible with OpenVINO, "
            "TensorFlow, JAX and Torch backends."
        )

    ov.serialize(ov_model, filepath)

    if actual_verbose:
        io_utils.print_msg(f"Saved OpenVINO IR at '{filepath}'.")
141+
142+
143+
def _check_jax_kwargs(kwargs):
144+
kwargs = kwargs.copy()
145+
if "is_static" not in kwargs:
146+
kwargs["is_static"] = True
147+
if "jax2tf_kwargs" not in kwargs:
148+
kwargs["jax2tf_kwargs"] = {
149+
"enable_xla": False,
150+
"native_serialization": False,
151+
}
152+
if kwargs["is_static"] is not True:
153+
raise ValueError(
154+
"`is_static` must be `True` in `kwargs` when using the jax backend."
155+
)
156+
if kwargs["jax2tf_kwargs"]["enable_xla"] is not False:
157+
raise ValueError(
158+
"`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` "
159+
"when using the jax backend."
160+
)
161+
if kwargs["jax2tf_kwargs"]["native_serialization"] is not False:
162+
raise ValueError(
163+
"`native_serialization` must be `False` in "
164+
"`kwargs['jax2tf_kwargs']` when using the jax backend."
165+
)
166+
return kwargs
167+
168+
169+
def get_concrete_fn(model, input_signature, **kwargs):
    """Return the concrete function of `model`'s default export endpoint.

    The model is registered on a fresh `ExportArchive` under
    `DEFAULT_ENDPOINT_NAME` with the given `input_signature`, and the
    archive's concrete function for that endpoint is returned.

    Args:
        model: The model to register on the export archive.
        input_signature: Input signature passed to the endpoint.
        **kwargs: Extra arguments forwarded to
            `ExportArchive.track_and_add_endpoint`. With the jax backend
            they are first validated and completed by `_check_jax_kwargs`.

    Returns:
        The result of `ExportArchive._get_concrete_fn` for the default
        endpoint.
    """
    active_backend = backend.backend()
    endpoint_kwargs = (
        _check_jax_kwargs(kwargs) if active_backend == "jax" else kwargs
    )

    archive = ExportArchive()
    archive.track_and_add_endpoint(
        DEFAULT_ENDPOINT_NAME, model, input_signature, **endpoint_kwargs
    )
    # Only the tensorflow backend runs the resource tracking step before
    # the concrete function is fetched.
    if active_backend == "tensorflow":
        archive._filter_and_track_resources()
    return archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME)

0 commit comments

Comments
 (0)