Skip to content

Commit ce08477

Browse files
author
wayuanho
committed
support Select with Where of opset 9
1 parent eb31fc3 commit ce08477

File tree

4 files changed

+86
-27
lines changed

4 files changed

+86
-27
lines changed

tests/test_backend.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,28 @@ def test_reverse_sequence_time_major(self):
15481548

15491549
@check_opset_min_version(8, "where")
15501550
def test_where(self):
1551+
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.float32)
1552+
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
1553+
dtype=np.float32)
1554+
false_result = np.array([-111, -222, -333, -444, -555, -666, -777, -888, -999, -1000],
1555+
dtype=np.float32)
1556+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1557+
picks = tf.where(x > -1, true_result, false_result)
1558+
_ = tf.identity(picks, name=_TFOUTPUT)
1559+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1560+
1561+
tf.reset_default_graph()
1562+
x_val = np.array(1, dtype=np.float32)
1563+
true_result = np.array(100, dtype=np.float32)
1564+
false_result = np.array(-111, dtype=np.float32)
1565+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1566+
picks = tf.where(x > -1, true_result, false_result)
1567+
_ = tf.identity(picks, name=_TFOUTPUT)
1568+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1569+
1570+
@check_opset_min_version(8, "where")
1571+
@check_target("rs6", "onnxruntime Where type limitation")
1572+
def test_where_int32(self):
15511573
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
15521574
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
15531575
dtype=np.int32)
@@ -1560,59 +1582,59 @@ def test_where(self):
15601582

15611583
@check_opset_min_version(8, "where")
15621584
def test_where_with_two_rank_input(self):
1563-
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
1585+
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.float32)
15641586
true_result = np.array([[111, 111], [222, 222], [333, 333], [444, 444], [555, 555],
15651587
[666, 666], [777, 777], [888, 888], [999, 999], [1000, 1000]],
1566-
dtype=np.int32)
1588+
dtype=np.float32)
15671589
false_result = np.array([[-111, -111], [-222, -222], [-333, -333], [-444, -444],
15681590
[-555, -555], [-666, -666], [-777, -777], [-888, -888],
15691591
[-999, -999], [-1000, -1000]],
1570-
dtype=np.int32)
1571-
x = tf.placeholder(tf.int32, [None], name=_TFINPUT)
1592+
dtype=np.float32)
1593+
x = tf.placeholder(tf.float32, [None], name=_TFINPUT)
15721594
picks = tf.where(tf.greater_equal(x, 0), true_result, false_result)
15731595
_ = tf.identity(picks, name=_TFOUTPUT)
15741596

15751597
self._run_test_case([_OUTPUT], {_INPUT: x_val})
15761598

15771599
@check_opset_min_version(8, "where")
15781600
def test_where_with_two_rank_condition(self):
1579-
x_val = np.array([[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]], dtype=np.int32)
1601+
x_val = np.array([[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]], dtype=np.float32)
15801602
true_result = np.array([[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]],
1581-
dtype=np.int32)
1603+
dtype=np.float32)
15821604
false_result = np.array([[-111, -222, -333, -444, -555, -666, -777, -888, -999, -1000]],
1583-
dtype=np.int32)
1584-
x = tf.placeholder(tf.int32, [1, 10], name=_TFINPUT)
1605+
dtype=np.float32)
1606+
x = tf.placeholder(tf.float32, [1, 10], name=_TFINPUT)
15851607
picks = tf.where(tf.greater_equal(x, 0), true_result, false_result)
15861608
_ = tf.identity(picks, name=_TFOUTPUT)
15871609

15881610
self._run_test_case([_OUTPUT], {_INPUT: x_val})
15891611

15901612
@check_opset_min_version(8, "where")
15911613
def test_where_with_three_rank_condition(self):
1592-
x_val = np.array([[[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]]], dtype=np.int32)
1614+
x_val = np.array([[[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]]], dtype=np.float32)
15931615
true_result = np.array([[[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]]],
1594-
dtype=np.int32)
1616+
dtype=np.float32)
15951617
false_result = np.array([[[-111, -222, -333, -444, -555, -666, -777, -888, -999, -1000]]],
1596-
dtype=np.int32)
1597-
x = tf.placeholder(tf.int32, [1, 1, 10], name=_TFINPUT)
1618+
dtype=np.float32)
1619+
x = tf.placeholder(tf.float32, [1, 1, 10], name=_TFINPUT)
15981620
picks = tf.where(tf.greater_equal(x, 0), true_result, false_result)
15991621
_ = tf.identity(picks, name=_TFOUTPUT)
16001622

16011623
self._run_test_case([_OUTPUT], {_INPUT: x_val})
16021624

16031625
@check_opset_min_version(8, "where")
16041626
def test_where_scalar(self):
1605-
x_val = np.array(6, dtype=np.int32)
1627+
x_val = np.array(6, dtype=np.float32)
16061628
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
1607-
dtype=np.int32)
1629+
dtype=np.float32)
16081630
false_result = np.array([-111, -222, -333, -444, -555, -666, -777, -888, -999, -1000],
1609-
dtype=np.int32)
1610-
x = tf.placeholder(tf.int32, [], name=_TFINPUT)
1631+
dtype=np.float32)
1632+
x = tf.placeholder(tf.float32, [], name=_TFINPUT)
16111633
picks = tf.where(tf.greater_equal(x, 0), true_result, false_result)
16121634
_ = tf.identity(picks, name=_TFOUTPUT)
16131635
self._run_test_case([_OUTPUT], {_INPUT: x_val})
16141636

1615-
@check_opset_min_version(9, "where")
1637+
@check_opset_min_version(9, "NonZero")
16161638
@check_target("rs6", "onnxruntime Transpose type limitation")
16171639
def test_where_with_cond_only(self):
16181640
for np_type, tf_type in [(np.int32, tf.int32), (np.float32, tf.float32)]:

tf2onnx/function/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tf2onnx.function.lstm_block_cell import lstm_block_cell_op
1111
from tf2onnx.function.matrixbandpart import matrixbandpart_op
1212
from tf2onnx.function.range import range_op7
13-
from tf2onnx.function.select import select_op8
13+
from tf2onnx.function.select import select_op8, select_op9
1414
from tf2onnx.function.softmax_cross_entropy_with_logits import softmax_cross_entropy_with_logits_op7
1515
from tf2onnx.function.softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op7
1616
from tf2onnx.function.softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op9
@@ -21,6 +21,7 @@
2121
"matrixbandpart_op",
2222
"range_op7",
2323
"select_op8",
24+
"select_op9",
2425
"softmax_cross_entropy_with_logits_op7",
2526
"sparse_softmax_cross_entropy_with_logits_op7",
2627
"sparse_softmax_cross_entropy_with_logits_op9",

tf2onnx/function/select.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,26 @@
1313
# pylint: disable=unused-argument,missing-docstring
1414

1515

16+
def select_op9(ctx, node, name, args):
17+
# T output = Select(bool condition, T x, T y)
18+
# T1 output = Where(bool condition, T1 x, T1 y)
19+
# NOTE: condition can be 1-dimension in tensorflow, while in onnx,
20+
# it should be broadcastable with other two inputs
21+
cond_shape = ctx.get_shape(node.input[0])
22+
make_sure(cond_shape is not None, "shape of {} is None".format(node.input[0]))
23+
input_shape = ctx.get_shape(node.input[1])
24+
if input_shape is None:
25+
input_shape = ctx.get_shape(node.input[2])
26+
make_sure(input_shape is not None, "input shape of {} is None".format(node.name))
27+
input_rank = len(input_shape)
28+
# if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
29+
if len(cond_shape) == 1 and input_rank > 1:
30+
broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
31+
shape_const = ctx.make_const(utils.make_name(name), np.array(broadcast_shape, dtype=np.int64))
32+
reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
33+
ctx.replace_input(node, node.input[0], reshape.output[0])
34+
35+
1636
def select_op8(ctx, node, name, args):
1737
# T output = Select(bool condition, T x, T y)
1838
# V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)

tf2onnx/tfonnx.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,20 +1171,25 @@ def minmax_op(ctx, node, name, args):
11711171
onnx_pb.TensorProto.DOUBLE
11721172
]
11731173
target_dtype = onnx_pb.TensorProto.FLOAT
1174+
need_cast = False
11741175
for inp in node.input:
11751176
dtype = ctx.get_dtype(inp)
1176-
utils.make_sure(dtype, "dtype of {} is None".format(inp))
1177+
utils.make_sure(dtype is not None, "dtype of {} is None".format(inp))
11771178
if dtype not in supported_dtypes:
11781179
inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
11791180
ctx.copy_shape(inp, inp_cast.output[0])
11801181
ctx.set_dtype(inp_cast.output[0], target_dtype)
1181-
origin_dtype = ctx.get_dtype(node.output[0])
1182-
utils.make_sure(origin_dtype is not None, "dtype of {} is None".format(node.output[0]))
1183-
ctx.set_dtype(node.output[0], target_dtype)
1184-
cast_name = utils.make_name(name)
1185-
cast_node = ctx.insert_new_node_on_output("Cast", node.output[0], name=cast_name, to=origin_dtype)
1186-
to_replace = [n for n in ctx.get_nodes() if n != cast_node]
1187-
ctx.replace_all_inputs(to_replace, node.output[0], cast_node.output[0])
1182+
need_cast = True
1183+
if need_cast:
1184+
origin_dtype = ctx.get_dtype(node.output[0])
1185+
utils.make_sure(origin_dtype is not None, "dtype of {} is None".format(node.output[0]))
1186+
ctx.set_dtype(node.output[0], target_dtype)
1187+
cast_name = utils.make_name(name)
1188+
cast_node = ctx.insert_new_node_on_output("Cast", node.output[0], name=cast_name, to=origin_dtype)
1189+
ctx.set_dtype(cast_node.output[0], origin_dtype)
1190+
ctx.copy_shape(node.output[0], cast_node.output[0])
1191+
to_replace = [n for n in ctx.get_nodes() if n != cast_node]
1192+
ctx.replace_all_inputs(to_replace, node.output[0], cast_node.output[0])
11881193

11891194
shapeo = ctx.get_shape(node.output[0])
11901195
needs_broadcast_op = []
@@ -1912,6 +1917,7 @@ def where_op(ctx, node, name, args):
19121917
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
19131918
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
19141919
"ReverseSequence": (reverse_op9, []),
1920+
"Select": (select_op9, ["Where"]),
19151921
"Sign": (sign_op9, []),
19161922
"Sinh": (direct_op, []),
19171923
"SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op9, []),
@@ -2209,6 +2215,7 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
22092215
"""
22102216
ignored_input_index = {
22112217
"Tile": [1], # Tile's second input can only be int64
2218+
"Where": [0], # Where's first input is bool
22122219
}
22132220
new_ops = []
22142221
org_ops = [n for n in ops]
@@ -2264,7 +2271,16 @@ def rewrite_incomplete_type_support_rs5(g, ops):
22642271

22652272

22662273
def rewrite_incomplete_type_support_rs6(g, ops):
2267-
return rewrite_incomplete_type_support(g, ops, ["Div", "IsNaN", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
2274+
return rewrite_incomplete_type_support(g, ops, [
2275+
"Div",
2276+
"IsNaN",
2277+
"ReduceSum",
2278+
"Slice",
2279+
"Split",
2280+
"Tile",
2281+
"Transpose",
2282+
"Where"
2283+
])
22682284

22692285

22702286
def rewrite_conv2d_with_pad(g, ops):

0 commit comments

Comments
 (0)