Skip to content

Commit 508f6be

Browse files
Fix Where/Select op for bool dtypes (#1673)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent f0d9a4d commit 508f6be

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

tests/test_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3318,6 +3318,18 @@ def func(x):
33183318
return tf.identity(picks, name=_TFOUTPUT)
33193319
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
33203320

3321+
@check_opset_min_version(7, "GreaterEqual")
3322+
def test_where_bool(self):
3323+
x_val = np.array([1, 2, -3, 4, -5], dtype=np.float32)
3324+
true_result = np.array([True, False, True, False, True],
3325+
dtype=np.bool)
3326+
false_result = np.array([False, True, False, True, True],
3327+
dtype=np.bool)
3328+
def func(x):
3329+
picks = tf.where(x > -1, true_result, false_result)
3330+
return tf.identity(picks, name=_TFOUTPUT)
3331+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3332+
33213333
@check_opset_min_version(7, "GreaterEqual")
33223334
#@check_target("rs6", "onnxruntime Where type limitation")
33233335
def test_where_int32(self):

tf2onnx/onnx_opset/controlflow.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,21 @@ def version_7(cls, ctx, node, **kwargs):
145145
utils.make_sure(len(node.input) > 1, "Select with only condition is not supported.")
146146
dtype = ctx.get_dtype(node.output[0])
147147
utils.make_sure(dtype != TensorProto.STRING, "Select with dtype string requires opset 9")
148+
tmp_dtype = dtype
149+
if tmp_dtype == TensorProto.BOOL:
150+
tmp_dtype = TensorProto.INT32
148151

149152
cond_shape = ctx.get_shape(node.input[0])
150153
input_shape = ctx.get_shape(node.input[1])
151154
if input_shape is None:
152155
input_shape = ctx.get_shape(node.input[2])
153156
input_rank = len(input_shape) if input_shape is not None else None
154157
cond_rank = len(cond_shape) if cond_shape is not None else None
158+
true_inp = node.input[1]
159+
false_inp = node.input[2]
160+
if tmp_dtype != dtype:
161+
true_inp = ctx.make_node("Cast", [true_inp], op_name_scope=node.name, attr={"to": tmp_dtype}).output[0]
162+
false_inp = ctx.make_node("Cast", [false_inp], op_name_scope=node.name, attr={"to": tmp_dtype}).output[0]
155163
# if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
156164
if node.type == "Select" and cond_rank == 1 and input_rank != 1:
157165
utils.make_sure(input_rank is not None, "input_rank unknown and cond_rank == 1")
@@ -161,18 +169,20 @@ def version_7(cls, ctx, node, **kwargs):
161169
ctx.replace_input(node, node.input[0], reshape.output[0], 0)
162170

163171
positive_cast = ctx.make_node("Cast", [node.input[0]], name=utils.make_name(node.name),
164-
attr={"to": dtype})
172+
attr={"to": tmp_dtype})
165173
negative = ctx.make_node("Not", [node.input[0]], name=utils.make_name(node.name))
166174
negative_cast = ctx.make_node("Cast", [negative.output[0]], name=utils.make_name(node.name),
167-
attr={"to": dtype})
168-
multiply_1 = ctx.make_node("Mul", [positive_cast.output[0], node.input[1]], name=utils.make_name(node.name))
169-
multiply_2 = ctx.make_node("Mul", [node.input[2], negative_cast.output[0]], name=utils.make_name(node.name))
175+
attr={"to": tmp_dtype})
176+
multiply_1 = ctx.make_node("Mul", [positive_cast.output[0], true_inp], name=utils.make_name(node.name))
177+
multiply_2 = ctx.make_node("Mul", [false_inp, negative_cast.output[0]], name=utils.make_name(node.name))
170178
add_name = node.name
171179
add_out = node.output
172180
shape = ctx.get_shape(node.output[0])
173181
ctx.remove_node(node.name)
174182
ctx.make_node("Add", [multiply_1.output[0], multiply_2.output[0]], outputs=add_out, name=add_name,
175-
dtypes=[dtype], shapes=[shape])
183+
dtypes=[tmp_dtype], shapes=[shape])
184+
if tmp_dtype != dtype:
185+
ctx.insert_new_node_on_output("Cast", node.output[0], to=dtype)
176186

177187
@classmethod
178188
def version_9(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)