Skip to content

Commit b620e4d

Browse files
authored
Fix Torch ONNX export not respecting InputSpec.name (#21646)
* fix(torch-export): ONNX export not respecting InputSpec.name * Update onnx.py * Pythonifying the code * api-gen * idek * Inputspec Attribute Error * Inputspec Attribute Error * InputSpec missing dtype * layer not built yet
1 parent 779fd0b commit b620e4d

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

keras/src/export/onnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def export_onnx(
8080
"The model provided has never called. "
8181
"It must be called at least once before export."
8282
)
83+
input_names = [
84+
getattr(spec, "name", None) or f"input_{i}"
85+
for i, spec in enumerate(input_signature)
86+
]
8387

8488
if backend.backend() in ("tensorflow", "jax"):
8589
from keras.src.utils.module_utils import tf2onnx
@@ -143,6 +147,7 @@ def export_onnx(
143147
sample_inputs,
144148
verbose=actual_verbose,
145149
opset_version=opset_version,
150+
input_names=input_names,
146151
dynamo=True,
147152
)
148153
if hasattr(onnx_program, "optimize"):
@@ -161,6 +166,7 @@ def export_onnx(
161166
filepath,
162167
verbose=actual_verbose,
163168
opset_version=opset_version,
169+
input_names=input_names,
164170
)
165171
else:
166172
raise NotImplementedError(

keras/src/export/onnx_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from keras.src import testing
1515
from keras.src import tree
1616
from keras.src.export import onnx
17+
from keras.src.layers.input_spec import InputSpec as InputSpec
1718
from keras.src.saving import saving_lib
1819
from keras.src.testing.test_utils import named_product
1920

@@ -269,3 +270,31 @@ def test_export_with_opset_version(self, opset_version):
269270
if opset_version is not None:
270271
onnx_model = onnx_lib.load(temp_filepath)
271272
self.assertEqual(onnx_model.opset_import[0].version, opset_version)
273+
274+
def test_export_with_input_names(self):
275+
"""Test ONNX export uses InputSpec.name for input names."""
276+
import onnx as onnx_lib
277+
278+
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
279+
model = get_model("sequential")
280+
batch_size = 3 if backend.backend() != "torch" else 1
281+
ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
282+
ref_output = model(ref_input)
283+
284+
# Test with custom input name
285+
input_spec = [
286+
InputSpec(
287+
name="custom_input", shape=(batch_size, 10), dtype="float32"
288+
)
289+
]
290+
onnx.export_onnx(model, temp_filepath, input_signature=input_spec)
291+
292+
onnx_model = onnx_lib.load(temp_filepath)
293+
input_names = [input.name for input in onnx_model.graph.input]
294+
self.assertIn("custom_input", input_names)
295+
296+
ort_session = onnxruntime.InferenceSession(temp_filepath)
297+
ort_inputs = {
298+
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
299+
}
300+
self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])

0 commit comments

Comments
 (0)