|
14 | 14 | from keras.src import testing
|
15 | 15 | from keras.src import tree
|
16 | 16 | from keras.src.export import onnx
|
| 17 | +from keras.src.layers.input_spec import InputSpec as InputSpec |
17 | 18 | from keras.src.saving import saving_lib
|
18 | 19 | from keras.src.testing.test_utils import named_product
|
19 | 20 |
|
@@ -269,3 +270,31 @@ def test_export_with_opset_version(self, opset_version):
|
269 | 270 | if opset_version is not None:
|
270 | 271 | onnx_model = onnx_lib.load(temp_filepath)
|
271 | 272 | self.assertEqual(onnx_model.opset_import[0].version, opset_version)
|
| 273 | + |
| 274 | + def test_export_with_input_names(self): |
| 275 | + """Test ONNX export uses InputSpec.name for input names.""" |
| 276 | + import onnx as onnx_lib |
| 277 | + |
| 278 | + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") |
| 279 | + model = get_model("sequential") |
| 280 | + batch_size = 3 if backend.backend() != "torch" else 1 |
| 281 | + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") |
| 282 | + ref_output = model(ref_input) |
| 283 | + |
| 284 | + # Test with custom input name |
| 285 | + input_spec = [ |
| 286 | + InputSpec( |
| 287 | + name="custom_input", shape=(batch_size, 10), dtype="float32" |
| 288 | + ) |
| 289 | + ] |
| 290 | + onnx.export_onnx(model, temp_filepath, input_signature=input_spec) |
| 291 | + |
| 292 | + onnx_model = onnx_lib.load(temp_filepath) |
| 293 | + input_names = [input.name for input in onnx_model.graph.input] |
| 294 | + self.assertIn("custom_input", input_names) |
| 295 | + |
| 296 | + ort_session = onnxruntime.InferenceSession(temp_filepath) |
| 297 | + ort_inputs = { |
| 298 | + k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) |
| 299 | + } |
| 300 | + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) |
0 commit comments