Skip to content

Commit 2cf8854

Browse files
Merge pull request #1089 from bedapisl/fix_equal
Add cast to same type before equal operator
2 parents b51df2f + 25ae1e1 commit 2cf8854

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

tests/test_backend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3625,6 +3625,24 @@ def test_conv2d_1_kernel_as_input(self):
36253625
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
36263626
self._conv_kernel_as_input_test(x_val, w_val)
36273627

3628+
def test_equal_with_different_parameters(self):
3629+
input_val = np.array([5], dtype=np.int32)
3630+
3631+
def func(input_val):
3632+
tensor = tf.zeros(input_val)
3633+
input_size = tf.size(tensor)
3634+
constant = tf.constant(3, dtype=tf.int32)
3635+
return tf.math.equal(input_size, constant, name="output")
3636+
3637+
feed_dict = {"input:0": input_val}
3638+
input_names_with_port = ["input:0"]
3639+
output_names_with_port = ["output:0"]
3640+
3641+
current_opset = self.config.opset
3642+
self.config.opset = 12
3643+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
3644+
self.config.opset = current_opset
3645+
36283646
@check_tf_min_version("1.14")
36293647
def test_rfft_ops(self):
36303648

tf2onnx/onnx_opset/logical.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
3434
graph.set_dtype(inp_cast.output[0], target_dtype)
3535

3636

37+
def _add_cast_to_same_type_to_inputs(graph, node):
38+
common_dtype = graph.get_dtype(node.input[0])
39+
40+
for inp in node.input[1:]:
41+
if graph.get_dtype(inp) != common_dtype:
42+
inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
43+
graph.copy_shape(inp, inp_cast.output[0])
44+
graph.set_dtype(inp_cast.output[0], common_dtype)
45+
46+
3747
@tf_op("LogicalNot", onnx_op="Not")
3848
class DirectOp:
3949
@classmethod
@@ -81,7 +91,8 @@ def version_7(cls, ctx, node, **kwargs):
8191

8292
@classmethod
8393
def version_11(cls, ctx, node, **kwargs):
84-
# starting with opset-11, equal supports all types
94+
# starting with opset-11, equal supports all types (but both operands must be of the same type)
95+
_add_cast_to_same_type_to_inputs(ctx, node)
8596
need_not = node.type == "NotEqual"
8697
if need_not:
8798
node.type = "Equal"

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _wrap_concat_with_cast(ctx, node):
6161
class Size:
6262
@classmethod
6363
def version_1(cls, ctx, node, **kwargs):
64-
pass
64+
ctx.set_dtype(node.output[0], onnx_pb.TensorProto.INT64)
6565

6666

6767
@tf_op("Flatten")

0 commit comments

Comments
 (0)