Skip to content

Commit 4a3d11b

Browse files
committed
Add cast to same type before equal operator
Signed-off-by: bedapisl <[email protected]>
1 parent 6ec695b commit 4a3d11b

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

tests/test_backend.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3602,6 +3602,23 @@ def test_conv2d_1_kernel_as_input(self):
36023602
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
36033603
self._conv_kernel_as_input_test(x_val, w_val)
36043604

3605+
def test_equal_with_different_parameters(self):
3606+
input_val = np.array([5], dtype=np.int32)
3607+
3608+
def func(input_val):
3609+
tensor = tf.zeros(input_val)
3610+
input_size = tf.size(tensor)
3611+
constant = tf.constant(3, dtype=tf.int32)
3612+
return tf.math.equal(input_size, constant, name="output")
3613+
3614+
feed_dict = {"input:0": input_val}
3615+
input_names_with_port = ["input:0"]
3616+
output_names_with_port = ["output:0"]
3617+
3618+
current_opset = self.config.opset
3619+
self.config.opset = 12
3620+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
3621+
self.config.opset = current_opset
36053622

36063623
if __name__ == '__main__':
36073624
unittest_main()

tf2onnx/onnx_opset/logical.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
3434
graph.set_dtype(inp_cast.output[0], target_dtype)
3535

3636

37+
def _add_cast_to_same_type_to_inputs(graph, node):
38+
common_dtype = graph.get_dtype(node.input[0])
39+
40+
for inp in node.input[1:]:
41+
if graph.get_dtype(inp) != common_dtype:
42+
inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
43+
graph.copy_shape(inp, inp_cast.output[0])
44+
graph.set_dtype(inp_cast.output[0], common_dtype)
45+
46+
3747
@tf_op("LogicalNot", onnx_op="Not")
3848
class DirectOp:
3949
@classmethod
@@ -81,7 +91,8 @@ def version_7(cls, ctx, node, **kwargs):
8191

8292
@classmethod
8393
def version_11(cls, ctx, node, **kwargs):
84-
# starting with opset-11, equal supports all types
94+
# starting with opset-11, equal supports all types (but both operands must be of the same type)
95+
_add_cast_to_same_type_to_inputs(ctx, node)
8596
need_not = node.type == "NotEqual"
8697
if need_not:
8798
node.type = "Equal"

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _wrap_concat_with_cast(ctx, node):
6161
class Size:
6262
@classmethod
6363
def version_1(cls, ctx, node, **kwargs):
64-
pass
64+
ctx.set_dtype(node.output[0], onnx_pb.TensorProto.INT64)
6565

6666

6767
@tf_op("Flatten")

0 commit comments

Comments
 (0)