Commit 4bdda42

Add opset_version in export_onnx. (#21541)
Parent: 92aefef

3 files changed, +63 −15 lines

keras/src/export/onnx.py

Lines changed: 25 additions & 4 deletions
@@ -11,7 +11,14 @@
 from keras.src.utils import io_utils
 
 
-def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs):
+def export_onnx(
+    model,
+    filepath,
+    verbose=None,
+    input_signature=None,
+    opset_version=None,
+    **kwargs,
+):
     """Export the model as a ONNX artifact for inference.
 
     This method lets you export a model to a lightweight ONNX artifact
@@ -31,6 +38,9 @@ def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs):
             inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
             `backend.KerasTensor`, or backend tensor. If not provided, it will
             be automatically computed. Defaults to `None`.
+        opset_version: Optional. An integer value that specifies the ONNX opset
+            version. If not provided, the default version for the backend will
+            be used. Defaults to `None`.
         **kwargs: Additional keyword arguments.
 
     **Note:** This feature is currently supported only with TensorFlow, JAX and
@@ -82,7 +92,10 @@ def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs):
         # Use `tf2onnx` to convert the `decorated_fn` to the ONNX format.
         patch_tf2onnx()  # TODO: Remove this once `tf2onnx` supports numpy 2.
         tf2onnx.convert.from_function(
-            decorated_fn, input_signature, output_path=filepath
+            decorated_fn,
+            input_signature,
+            opset=opset_version,
+            output_path=filepath,
         )
 
     elif backend.backend() == "torch":
@@ -126,7 +139,11 @@ def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs):
         try:
             # Try the TorchDynamo-based ONNX exporter first.
             onnx_program = torch.onnx.export(
-                model, sample_inputs, verbose=actual_verbose, dynamo=True
+                model,
+                sample_inputs,
+                verbose=actual_verbose,
+                opset_version=opset_version,
+                dynamo=True,
             )
             if hasattr(onnx_program, "optimize"):
                 onnx_program.optimize()  # Only supported by torch>=2.6.0.
@@ -139,7 +156,11 @@ def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs):
 
             # Fall back to the TorchScript-based ONNX exporter.
             torch.onnx.export(
-                model, sample_inputs, filepath, verbose=actual_verbose
+                model,
+                sample_inputs,
+                filepath,
+                verbose=actual_verbose,
+                opset_version=opset_version,
             )
     else:
         raise NotImplementedError(
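
Taken together, the change above threads one new user-facing argument through two different converter APIs: on the TensorFlow and JAX backends it is passed to `tf2onnx.convert.from_function` as `opset=...`, while on the Torch backend it is passed to `torch.onnx.export` as `opset_version=...` in both the TorchDynamo and TorchScript paths. A minimal usage sketch of the internal helper, assuming a TensorFlow or JAX backend with `tf2onnx` installed and a hypothetical toy model (the public entry point is `Model.export`, covered further below):

import keras
from keras.src.export.onnx import export_onnx

# Hypothetical model, purely for illustration.
model = keras.Sequential([keras.Input(shape=(10,)), keras.layers.Dense(4)])

# Pin the exported graph to ONNX opset 18. Leaving opset_version=None
# keeps the previous behavior: the converter picks its default opset.
export_onnx(model, "model.onnx", opset_version=18)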

keras/src/export/onnx_test.py

Lines changed: 24 additions & 0 deletions
@@ -245,3 +245,27 @@ def build(self, y_shape, x_shape):
             )
         }
         ort_session.run(None, ort_inputs)
+
+    @parameterized.named_parameters(named_product(opset_version=[None, 18]))
+    def test_export_with_opset_version(self, opset_version):
+        import onnx as onnx_lib
+
+        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
+        model = get_model("sequential")
+        batch_size = 3 if backend.backend() != "torch" else 1
+        ref_input = np.random.normal(size=(batch_size, 10))
+        ref_input = ref_input.astype("float32")
+        ref_output = model(ref_input)
+
+        onnx.export_onnx(
+            model, temp_filepath, opset_version=opset_version, verbose=True
+        )
+        ort_session = onnxruntime.InferenceSession(temp_filepath)
+        ort_inputs = {
+            k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
+        }
+        self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])
+
+        if opset_version is not None:
+            onnx_model = onnx_lib.load(temp_filepath)
+            self.assertEqual(onnx_model.opset_import[0].version, opset_version)
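
The new test asserts on `opset_import[0].version`. For reference, an ONNX model can declare several opset imports, one per operator domain; the default domain (empty string) is the one the assertion targets. A small sketch of inspecting this by hand, assuming `onnx` is installed and an exported `model.onnx` exists:

import onnx

# opset_import is a list of (domain, version) entries; the empty domain
# string denotes the default ONNX operator set checked by the test above.
onnx_model = onnx.load("model.onnx")
for opset in onnx_model.opset_import:
    print(opset.domain or "<default>", opset.version)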

keras/src/models/model.py

Lines changed: 14 additions & 11 deletions
@@ -548,17 +548,20 @@ def export(
                 `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If
                 not provided, it will be automatically computed. Defaults to
                 `None`.
-            **kwargs: Additional keyword arguments:
-                - Specific to the JAX backend and `format="tf_saved_model"`:
-                    - `is_static`: Optional `bool`. Indicates whether `fn` is
-                        static. Set to `False` if `fn` involves state updates
-                        (e.g., RNG seeds and counters).
-                    - `jax2tf_kwargs`: Optional `dict`. Arguments for
-                        `jax2tf.convert`. See the documentation for
-                        [`jax2tf.convert`](
-                        https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
-                        If `native_serialization` and `polymorphic_shapes` are
-                        not provided, they will be automatically computed.
+            **kwargs: Additional keyword arguments.
+                - `is_static`: Optional `bool`. Specific to the JAX backend and
+                    `format="tf_saved_model"`. Indicates whether `fn` is static.
+                    Set to `False` if `fn` involves state updates (e.g., RNG
+                    seeds and counters).
+                - `jax2tf_kwargs`: Optional `dict`. Specific to the JAX backend
+                    and `format="tf_saved_model"`. Arguments for
+                    `jax2tf.convert`. See the documentation for
+                    [`jax2tf.convert`](
+                    https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+                    If `native_serialization` and `polymorphic_shapes` are not
+                    provided, they will be automatically computed.
+                - `opset_version`: Optional `int`. Specific to `format="onnx"`.
+                    An integer value that specifies the ONNX opset version.
 
         **Note:** This feature is currently supported only with TensorFlow, JAX
         and Torch backends.
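
For completeness, a minimal sketch of the public path this docstring documents, with a hypothetical model; `opset_version` travels through `**kwargs` and is only consumed when `format="onnx"`:

import keras

# Hypothetical model, purely for illustration.
model = keras.Sequential([keras.Input(shape=(10,)), keras.layers.Dense(4)])

# Forwarded to export_onnx() via **kwargs. Requires tf2onnx on the
# TensorFlow/JAX backends, or torch's ONNX exporters on the Torch backend.
model.export("model.onnx", format="onnx", opset_version=18)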

0 commit comments
