Skip to content

Commit 5330854

Browse files
Unit test for from_function with non-tensor inputs in spec (#1492)
* Filter structured inputs to only include TensorSpecs Signed-off-by: Tom Wildenhain <[email protected]> * Unit test for from_function with non-tensor inputs in spec Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 0d71be8 commit 5330854

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/test_api.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,35 @@ def func(x, y):
100100
oy = self.run_onnxruntime(output_path, {"x": x, "y": y}, output_names)
101101
self.assertAllClose(ky, oy[0], rtol=0.3, atol=0.1)
102102

103+
@check_tf_min_version("2.0")
104+
def test_function_non_tensor_inputs(self):
105+
class Foo:
106+
a = 42
107+
108+
@tf.function
109+
def func(foo, a, x, b, w):
110+
if a:
111+
return x + foo.a + b / w
112+
return x + b
113+
114+
output_path = os.path.join(self.test_data_directory, "model.onnx")
115+
x = np.arange(20).reshape([2, 10]).astype(np.float32)
116+
w = np.arange(10).reshape([10]).astype(np.float32)
117+
118+
res_tf = func(Foo(), True, x, 123, w)
119+
spec = (
120+
Foo(),
121+
True,
122+
tf.TensorSpec((2, None), tf.float32, name="x"),
123+
123,
124+
tf.TensorSpec((None), tf.float32, name="w")
125+
)
126+
model_proto, _ = tf2onnx.convert.from_function(func, input_signature=spec,
127+
opset=self.config.opset, output_path=output_path)
128+
output_names = [n.name for n in model_proto.graph.output]
129+
res_onnx = self.run_onnxruntime(output_path, {"x": x, "w": w}, output_names)
130+
self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5)
131+
103132
@check_tf_min_version("1.15")
104133
def _test_graphdef(self):
105134
def func(x, y):

0 commit comments

Comments
 (0)