Commit 510b8d2

Merge pull request #351 from nbcsm/revert
Revert "replace conversion logic of "select" with a simpler one"
2 parents d2ece0c + dd28625

4 files changed (+235 −40 lines)

tests/test_backend.py

Lines changed: 6 additions & 5 deletions
@@ -1446,7 +1446,8 @@ def test_reverse_sequence_time_major(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(7, "where")
+    # @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
+    @unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so Select op is not supported")
     def test_where(self):
         x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
         true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
@@ -1458,7 +1459,7 @@ def test_where(self):
         _ = tf.identity(picks, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(7, "where")
+    @check_opset_min_version(8, "where")
     def test_where_with_two_rank_input(self):
         x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
         true_result = np.array([[111, 111], [222, 222], [333, 333], [444, 444], [555, 555],
@@ -1474,7 +1475,7 @@ def test_where_with_two_rank_input(self):
 
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(7, "where")
+    @check_opset_min_version(8, "where")
     def test_where_with_two_rank_condition(self):
         x_val = np.array([[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]], dtype=np.int32)
         true_result = np.array([[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]],
@@ -1487,7 +1488,7 @@ def test_where_with_two_rank_condition(self):
 
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(7, "where")
+    @check_opset_min_version(8, "where")
     def test_where_with_three_rank_condition(self):
         x_val = np.array([[[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]]], dtype=np.int32)
         true_result = np.array([[[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]]],
@@ -1500,7 +1501,7 @@ def test_where_with_three_rank_condition(self):
 
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(7, "where")
+    @check_opset_min_version(8, "where")
     def test_where_scalar(self):
         x_val = np.array(6, dtype=np.int32)
         true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
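For reference, the behavior these tests pin down is TensorFlow's element-wise select: take the true-branch value where the condition holds, otherwise the false-branch value. A minimal NumPy sketch of the rank-1 case (the input and true-branch values mirror the test; the false-branch values are illustrative stand-ins, not the test's actual data):

import numpy as np

x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000], dtype=np.int32)
false_result = x_val * 10  # illustrative stand-in for the test's false branch

# tf.where(cond, x, y), and hence TF Select, computes this element-wise pick
picks = np.where(x_val > 0, true_result, false_result)
assert picks[0] == 111 and picks[2] == -30  # true branch at index 0, false branch at index 2

ONNX had no native Where op at these opsets, which is why the converter has to synthesize the pick from other ops, and why the tests are gated on the opset the synthesis requires.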

tf2onnx/function/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from .gathernd import gathernd_op
 from .matrixbandpart import matrixbandpart_op
 from .range import range_op7
-from .select import select_op7
+from .select import select_op8
 from .sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op
 
-__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op7", "sparse_softmax_cross_entropy_with_logits_op"]
+__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op8", "sparse_softmax_cross_entropy_with_logits_op"]
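Since __all__ defines what `from tf2onnx.function import *` re-exports, the rename has to touch both the import and the export list. A quick illustrative check (my own snippet, not from the repo):

import tf2onnx.function as fn

assert "select_op8" in fn.__all__      # the handler registered by tfonnx.py's opset-8 table
assert not hasattr(fn, "select_op7")   # the opset-7 handler is gone after this commit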

tf2onnx/function/select.py

Lines changed: 226 additions & 32 deletions
@@ -4,42 +4,236 @@
 """
 tf2onnx.tf2onnx - select op conversion
 """
-
+import numpy as np
+from onnx.onnx_pb import TensorProto
 from tf2onnx import utils
+from tf2onnx.utils import port_name, make_sure
 
 
 # pylint: disable=unused-argument,missing-docstring
 
 
-def select_op7(ctx, node, name, args):
+def select_op8(ctx, node, name, args):
     # T output = Select(bool condition, T x, T y)
-    select_x_dtype = ctx.get_dtype(node.input[1])
-    select_x_shape = ctx.get_shape(node.input[1])
-    cond = node.inputs[0]
-    cond_shape = ctx.get_shape(cond.output[0])
-    utils.make_sure(select_x_shape is not None and cond_shape is not None, "ranks of the inputs are needed")
-
-    added_nodes = []
-    true_mask = ctx.make_node("Cast", cond.output, attr={"to": select_x_dtype})
-    cond_not = ctx.make_node("Not", cond.output)
-    false_mask = ctx.make_node("Cast", cond_not.output, attr={"to": select_x_dtype})
-    added_nodes.extend([true_mask, cond_not, false_mask])
-    # the broadcasting rule of select differs from the common rule:
-    # for example, if input_x has shape (10) and input_y has shape (10, 2), common
-    # broadcasting fails, while tf "select" treats input_x as (10, 1) and repeats it
-    # along the last dimension, so an unsqueeze is inserted here
-    unsqueeze_dim_num = len(select_x_shape) - len(cond_shape)
-    utils.make_sure(unsqueeze_dim_num >= 0, "rank of select_x must not be less than rank of cond")
-    if unsqueeze_dim_num != 0:
-        unsqueeze_dim_start = len(select_x_shape)
-        axes = range(unsqueeze_dim_start - 1, unsqueeze_dim_start + unsqueeze_dim_num - 1)
-        true_mask = ctx.make_node("Unsqueeze", true_mask.output, attr={"axes": axes})
-        false_mask = ctx.make_node("Unsqueeze", false_mask.output, attr={"axes": axes})
-        added_nodes.extend([true_mask, false_mask])
-
-    select_from_true = ctx.make_node("Mul", [true_mask.output[0], node.input[1]])
-    select_from_false = ctx.make_node("Mul", [false_mask.output[0], node.input[2]])
-    res = ctx.make_node("Add", [select_from_true.output[0], select_from_false.output[0]],
-                        name=node.name, outputs=node.output,
-                        shapes=[ctx.get_shape(node.output[0])], dtypes=[ctx.get_dtype(node.output[0])])
-    return [*added_nodes, select_from_true, select_from_false, res]
+    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
+    utils.make_sure(len(node.input) > 1, "Select with only a condition is not supported.")
+
+    nodes = []
+    true_data_type = ctx.get_dtype(node.input[1])
+    true_data_shape = ctx.get_shape(node.input[1])
+    make_sure(true_data_type is not None, "select true data dtype cannot be None")
+    make_sure(true_data_shape is not None, "select true data shape cannot be None")
+
+    condition_shape = ctx.get_shape(node.input[0])
+    utils.make_sure(condition_shape is not None, "condition shape is None")
+    rank = len(condition_shape)
+
+    utils.make_sure(rank >= 0, "rank should be >= 0")
+    val_output_id = None
+    if rank > 0:
+        # create nodes getting the shape of the condition
+        shape_node_output_shape = [rank]
+        shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name,
+                                   shapes=[shape_node_output_shape], dtypes=[TensorProto.INT64])
+        nodes.append(shape_node)
+
+        # todo(pengwa): leverage rewrite_incomplete_type_support_onnxruntime for these
+        # casts once the shape-inferencing bug is fixed.
+        # workaround: onnxruntime does not support Split-2, so add Casts before and after.
+        target_dtype = TensorProto.FLOAT
+        shape_f_node = ctx.make_node("Cast", [shape_node.output[0]], attr={"to": target_dtype},
+                                     shapes=[shape_node_output_shape], dtypes=[target_dtype],
+                                     op_name_scope=node.name)
+        nodes.append(shape_f_node)
+
+        split_attr = [1 for i in range(rank)]
+        output_shapes = [[1] for i in range(rank)]
+        output_dtypes = [target_dtype for i in range(rank)]
+        split_node = ctx.make_node("Split", [shape_f_node.output[0]], output_count=rank,
+                                   attr={"split": split_attr}, shapes=output_shapes,
+                                   dtypes=output_dtypes, op_name_scope=node.name)
+        nodes.append(split_node)
+
+        trip_cnts = []
+        for i in range(rank):
+            output_id = split_node.output[i]
+            output_shape = ctx.get_shape(output_id)
+            target_dtype = TensorProto.INT64
+            shape_i_node = ctx.make_node("Cast", [output_id], attr={"to": target_dtype},
+                                         shapes=[output_shape], dtypes=[target_dtype],
+                                         op_name_scope=node.name)
+            trip_cnts.append(shape_i_node.output[0])
+            nodes.append(shape_i_node)
+        # workaround ends
+
+        onnx_nodes = create_loop_op(ctx, node.input, true_data_type, true_data_shape, trip_cnts, rank)
+        nodes.extend(onnx_nodes)
+        loop_node = onnx_nodes[-1]
+
+        val_output_id = loop_node.output[1]
+    elif rank == 0:
+        if_node, val_output_id = create_if_op(ctx, node.input, true_data_type, true_data_shape)
+        nodes.append(if_node)
+
+    ctx.copy_shape(node.output[0], val_output_id)
+    ctx.set_dtype(node.output[0], true_data_type)
+
+    output_node = ctx.make_node("Identity", [val_output_id], outputs=node.output,
+                                shapes=[ctx.get_shape(val_output_id)], dtypes=[true_data_type])
+    nodes.append(output_node)
+
+    return nodes
+
+
+# gather_input_ids is a list of 3 tensor ids:
+# 0: condition data to gather on
+# 1: true result to gather on
+# 2: false result to gather on
+def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_input_ids, rank):
+    nodes = []
+    cond_var_name = utils.make_name("cond_var")
+    nodes.append(g.make_const(cond_var_name, np.array(True, dtype=np.bool)))
+
+    # Loop requires at least one loop-carried variable, so add a useless fake one.
+    fake_val_name = utils.make_name("fake_var")
+    nodes.append(g.make_const(fake_val_name, np.array(0.0, dtype=np.float32)))
+
+    if rank < 1:
+        raise ValueError("rank is < 1")
+    trip_count_input_id = trip_count_input_ids[-1 * rank]
+
+    loop_inputs = [trip_count_input_id,  # trip count
+                   cond_var_name,  # termination condition
+                   fake_val_name  # initial value of loop-carried dependencies
+                   ]
+    # define an extra scan output
+    loop_node = g.make_node("Loop", loop_inputs, output_count=2, op_name_scope="select_loop",
+                            skip_conversion=False)
+    loop_body = create_loop_body_graph(g, gather_input_ids, output_type, output_shape, trip_count_input_ids,
+                                       rank, loop_node.name)
+    loop_node.set_body_graph_as_attr("body", loop_body)
+    nodes.append(loop_node)
+    return nodes
+
+
+def get_inputs_for_current_iteration(g, input_id, iter_index):
+    nodes = []
+    cond_gather_node = g.make_node("Gather", [input_id, iter_index])
+    nodes.append(cond_gather_node)
+
+    cur_cond_val_scalar_node = g.make_node("Squeeze", [cond_gather_node.output[0]], attr={"axes": [0]})
+    nodes.append(cur_cond_val_scalar_node)
+
+    return nodes, cur_cond_val_scalar_node.output[0]
+
+
+def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
+                           rank, loop_name):
+    g = parent_g.create_new_graph_with_same_config()
+    iter_name = utils.make_name("i")
+    cond_name = utils.make_name("cond")
+    fake_var_name = utils.make_name("fake_var")
+
+    g.add_graph_input(iter_name, TensorProto.INT64, (1,))  # iteration_num
+    g.add_graph_input(cond_name, TensorProto.BOOL, ())  # condition
+    g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency
+    nodes = g.get_nodes()
+    # get the i-th value of the condition
+    cond_input_id = gather_input_ids[0]
+    new_nodes, cond_input_id_for_current_iter = get_inputs_for_current_iteration(g, cond_input_id, iter_name)
+    nodes.extend(new_nodes)
+
+    # get the i-th value of the true values
+    true_input_id = gather_input_ids[1]
+    new_nodes, true_input_id_for_current_iter = get_inputs_for_current_iteration(g, true_input_id, iter_name)
+    nodes.extend(new_nodes)
+
+    # get the i-th value of the false values
+    false_input_id = gather_input_ids[2]
+    new_nodes, false_input_id_for_current_iter = get_inputs_for_current_iteration(g, false_input_id, iter_name)
+    nodes.extend(new_nodes)
+
+    input_ids_for_current_iter = [cond_input_id_for_current_iter, true_input_id_for_current_iter,
+                                  false_input_id_for_current_iter]
+    output_id = None
+    rank = rank - 1
+    if rank >= 1:
+        nodes_1 = create_loop_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:],
+                                 trip_count_input_ids, rank)
+        loop_1 = nodes_1[-1]
+        output_id = loop_1.output[1]
+        nodes.extend(nodes_1)
+    elif rank == 0:
+        if_node, if_node_output_id = create_if_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:])
+        output_id = if_node_output_id
+        nodes.append(if_node)
+
+    output_identity_name = utils.make_name("loop_output")
+    loop_output_id = utils.port_name(output_identity_name)
+    loop_output_node = g.make_node(
+        'Identity',
+        [output_id],
+        outputs=[loop_output_id],
+        name=output_identity_name
+    )
+    nodes.append(loop_output_node)
+
+    cond_identity_name = utils.make_name("cond_output")
+    cond_output_id = utils.port_name(cond_identity_name)
+    identity_node = g.make_node(
+        'Identity',
+        [cond_name],
+        outputs=[cond_output_id],
+        name=cond_identity_name
+    )
+    nodes.append(identity_node)
+
+    fake_var_identity_name = utils.make_name("fake_var_output")
+    fake_var_output_id = utils.port_name(fake_var_identity_name)
+    identity_node = g.make_node(
+        'Identity',
+        [fake_var_name],
+        outputs=[fake_var_output_id],
+        name=fake_var_identity_name
+    )
+    nodes.append(identity_node)
+
+    g.set_nodes(nodes)
+
+    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
+    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())
+
+    # use None for all dims and keep only the original rank, because dims have been
+    # observed to change inside the loop.
+    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))
+
+    return g
+
+
+def create_if_op(g, input_ids, output_data_type, output_shape):
+    op_name = utils.make_name("If")
+    true_graph = create_body_graph_for_if_branch(g, output_data_type, output_shape, input_ids[1], op_name)
+    false_graph = create_body_graph_for_if_branch(g, output_data_type, output_shape, input_ids[2], op_name)
+    out_name = port_name(op_name)
+
+    # output a scalar
+    if_node = g.make_node("If", [input_ids[0]], outputs=[out_name], name=op_name, skip_conversion=False)
+    if_node.set_body_graph_as_attr("then_branch", true_graph)
+    if_node.set_body_graph_as_attr("else_branch", false_graph)
+    return if_node, out_name
+
+
+def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
+    g = parent_g.create_new_graph_with_same_config()
+    nodes = []
+    name = utils.make_name("Identity")
+    identity_node = g.make_node(
+        'Identity',
+        inputs=[chosen_cur_cond_val_out_name],
+        outputs=['y'],
+        name=name
+    )
+    nodes.append(identity_node)
+    g.set_nodes(nodes)
+    g.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
+    return g
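The reinstated handler abandons the mask-and-add trick above (Cast/Mul/Add, which needed the broadcasting workaround) in favor of one ONNX Loop per condition dimension: each Loop body Gathers the i-th slice of the condition, true, and false inputs and recurses one rank down, bottoming out in an If node that forwards a whole branch. A hedged pure-Python model of what the generated graph computes (my own illustration of the control flow, not code from this commit):

import numpy as np

def select_by_loops(cond, true_vals, false_vals):
    # rank 0: the create_if_op case - an If node forwards one whole branch
    if np.ndim(cond) == 0:
        return true_vals if cond else false_vals
    # rank > 0: the create_loop_op case - the trip count is dim 0 of cond;
    # each iteration Gathers slice i and recurses one rank down, and the
    # per-iteration results become the Loop's extra scan output
    scan_output = []
    for i in range(np.shape(cond)[0]):
        scan_output.append(select_by_loops(cond[i], true_vals[i], false_vals[i]))
    return np.array(scan_output)

cond = np.array([[True, False], [False, True]])
t = np.array([[1, 2], [3, 4]])
f = np.array([[10, 20], [30, 40]])
print(select_by_loops(cond, t, f))  # [[ 1 20] [30  4]]

The Shape → Cast(float) → Split → Cast(int64) chain at the top of select_op8 exists only to produce those per-dimension trip counts; the float round trip is the Split-2 onnxruntime workaround called out in the code comments.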

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
@@ -1834,11 +1834,11 @@ def where_op(ctx, node, name, args):
     "If": (direct_op, []),
     "Loop": (direct_op, []),
     "Scan": (direct_op, []),
-    "Select": (select_op7, []),
 }
 
 _OPSET_8 = {
     "ReverseSequence": (reverse_op8, []),  # make use of scan
+    "Select": (select_op8, []),
 }
 
 _OPSET_9 = {
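The _OPSET_N tables map TF op names to (handler, args) pairs, and the converter applies every table whose opset is at most the target, with later tables overriding earlier ones. Moving "Select" from _OPSET_7 to _OPSET_8 therefore makes the op convertible only at --opset 8 and above, matching the test changes earlier in this commit. A hedged sketch of that layering (stub handlers and an assumed build_op_map helper for illustration, not the exact tfonnx.py code):

def direct_op(ctx, node, name, args): ...
def reverse_op8(ctx, node, name, args): ...
def select_op8(ctx, node, name, args): ...

_OPSET_7 = {"If": (direct_op, []), "Loop": (direct_op, []), "Scan": (direct_op, [])}
_OPSET_8 = {"ReverseSequence": (reverse_op8, []), "Select": (select_op8, [])}

def build_op_map(target_opset):
    ops = {}
    for opset, table in ((7, _OPSET_7), (8, _OPSET_8)):
        if opset <= target_opset:
            ops.update(table)  # later opsets override earlier handlers
    return ops

assert "Select" in build_op_map(8)
assert "Select" not in build_op_map(7)  # hence test_where now requires opset >= 8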
