Skip to content

Commit 276bdea

Browse files
author
Mike Essenmacher
authored
Change Equal 11 for string input (#2149)
* Change Equal 11 for string input * Unify the dtype of all of inputs and add backend test --------- Signed-off-by: Mike Essenmacher <[email protected]>
1 parent aaaea95 commit 276bdea

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,14 @@ def func(x1, x2):
14211421
return tf.identity(mi, name=_TFOUTPUT)
14221422
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
14231423

1424+
def test_equal_string(self):
1425+
x_val1 = np.array(['1'], dtype=np.string_)
1426+
x_val2 = np.array(['2'], dtype=np.string_)
1427+
def func(x1, x2):
1428+
mi = tf.equal(x1, x2)
1429+
return tf.identity(mi, name=_TFOUTPUT)
1430+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
1431+
14241432
def test_equal(self):
14251433
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
14261434
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))

tf2onnx/onnx_opset/logical.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,21 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
2929
graph.copy_shape(inp, inp_cast.output[0])
3030
graph.set_dtype(inp_cast.output[0], target_dtype)
3131

32-
33-
def _add_cast_to_same_type_to_inputs(graph, node):
32+
def _add_cast_to_same_type_to_inputs(graph, node, supported_dtypes, target_dtype):
3433
common_dtype = graph.get_dtype(node.input[0])
34+
if common_dtype not in supported_dtypes:
35+
common_dtype = target_dtype
3536

36-
for inp in node.input[1:]:
37+
for inp in node.input:
3738
if graph.get_dtype(inp) != common_dtype:
3839
inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
3940
graph.copy_shape(inp, inp_cast.output[0])
4041
graph.set_dtype(inp_cast.output[0], common_dtype)
42+
if graph.is_const(inp) and graph.get_tensor_value(inp) == '':
43+
# Convert '' string constant to -1 int
44+
# https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286
45+
const_node = graph.get_node_by_output(inp)
46+
const_node.set_tensor_value(utils.np.array(-1))
4147

4248

4349
@tf_op("LogicalNot", onnx_op="Not")
@@ -92,8 +98,24 @@ def version_7(cls, ctx, node, **kwargs):
9298

9399
@classmethod
94100
def version_11(cls, ctx, node, **kwargs):
95-
# starting with opset-11, equal supports all types (but both operands must be of the same type)
96-
_add_cast_to_same_type_to_inputs(ctx, node)
101+
# starting with opset-11, equal supports all numerical types (but both operands must be of the same type)
102+
# string type is not supported
103+
supported_dtypes = [
104+
TensorProto.BOOL,
105+
TensorProto.DOUBLE,
106+
TensorProto.FLOAT,
107+
TensorProto.FLOAT16,
108+
TensorProto.INT8,
109+
TensorProto.INT16,
110+
TensorProto.INT32,
111+
TensorProto.INT64,
112+
TensorProto.UINT8,
113+
TensorProto.UINT16,
114+
TensorProto.UINT32,
115+
TensorProto.UINT64
116+
]
117+
target_dtype = TensorProto.INT32
118+
_add_cast_to_same_type_to_inputs(ctx, node, supported_dtypes, target_dtype)
97119
need_not = node.type == "NotEqual"
98120
if need_not:
99121
node.type = "Equal"

0 commit comments

Comments
 (0)