
Commit 9db30fb

replace conversion logic of "select" with a simpler one
1 parent 368843a commit 9db30fb

4 files changed: +40 -235 lines changed

tests/test_backend.py

Lines changed: 5 additions & 6 deletions
@@ -1446,8 +1446,7 @@ def test_reverse_sequence_time_major(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    # @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
-    @unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so Select op is not supported")
+    @check_opset_min_version(7, "where")
     def test_where(self):
         x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
         true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
@@ -1459,7 +1458,7 @@ def test_where(self):
         _ = tf.identity(picks, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(8, "where")
+    @check_opset_min_version(7, "where")
     def test_where_with_two_rank_input(self):
         x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
         true_result = np.array([[111, 111], [222, 222], [333, 333], [444, 444], [555, 555],
@@ -1475,7 +1474,7 @@ def test_where_with_two_rank_input(self):
 
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(8, "where")
+    @check_opset_min_version(7, "where")
     def test_where_with_two_rank_condition(self):
         x_val = np.array([[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]], dtype=np.int32)
         true_result = np.array([[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]],
@@ -1488,7 +1487,7 @@ def test_where_with_two_rank_condition(self):
 
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(8, "where")
+    @check_opset_min_version(7, "where")
     def test_where_with_three_rank_condition(self):
         x_val = np.array([[[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]]], dtype=np.int32)
         true_result = np.array([[[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]]],
@@ -1501,7 +1500,7 @@ def test_where_with_three_rank_condition(self):
 
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @check_opset_min_version(8, "where")
+    @check_opset_min_version(7, "where")
     def test_where_scalar(self):
         x_val = np.array(6, dtype=np.int32)
         true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
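Note on the decorator change: check_opset_min_version marks a test to be skipped when the opset targeted by the test run is below the stated minimum, so these tests now run from opset 7 onward, matching the converter change below. A minimal sketch of what such a decorator can look like (hypothetical; the real helper lives in the tf2onnx test utilities, and OPSET here stands in for the opset under test):

import unittest

OPSET = 7  # assumption: the opset targeted by the current test run

def check_opset_min_version(min_version, op_name=""):
    # skip the test unless the run targets at least min_version
    reason = "conversion of %s requires opset >= %d" % (op_name, min_version)
    return unittest.skipIf(OPSET < min_version, reason)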

tf2onnx/function/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from .gathernd import gathernd_op
 from .matrixbandpart import matrixbandpart_op
 from .range import range_op7
-from .select import select_op8
+from .select import select_op7
 from .sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op
 
-__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op8", "sparse_softmax_cross_entropy_with_logits_op"]
+__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op7", "sparse_softmax_cross_entropy_with_logits_op"]

tf2onnx/function/select.py

Lines changed: 32 additions & 226 deletions
@@ -4,236 +4,42 @@
 """
 tf2onnx.tf2onnx - select op conversion
 """
-import numpy as np
-from onnx.onnx_pb import TensorProto
+
 from tf2onnx import utils
-from tf2onnx.utils import port_name, make_sure
 
 
 # pylint: disable=unused-argument,missing-docstring
 
 
-def select_op8(ctx, node, name, args):
+def select_op7(ctx, node, name, args):
     # T output = Select(bool condition, T x, T y)
-    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
-    utils.make_sure(len(node.input) > 1, "Select with only condition is not supported.")
-
-    nodes = []
-    true_data_type = ctx.get_dtype(node.input[1])
-    true_data_shape = ctx.get_shape(node.input[1])
-    make_sure(true_data_type is not None, "select true data dtype cannot be None")
-    make_sure(true_data_shape is not None, "select true data shape cannot be None")
-
-    condition_shape = ctx.get_shape(node.input[0])
-    utils.make_sure(condition_shape is not None, "condition shape is None")
-    rank = len(condition_shape)
-
-    utils.make_sure(rank >= 0, "rank should be >= 0")
-    val_output_id = None
-    if rank > 0:
-        # create nodes getting shape of condition
-        shape_node_output_shape = [rank]
-        shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name,
-                                   shapes=[shape_node_output_shape], dtypes=[TensorProto.INT64])
-        nodes.append(shape_node)
-
-        # todo(pengwa): move those leveraging rewrite_incomplete_type_support_onnxruntime
-        # after the shape-inferencing bug is fixed.
-        # workaround: onnxruntime does not support Split-2, so add casts before and after.
-        target_dtype = TensorProto.FLOAT
-        shape_f_node = ctx.make_node("Cast", [shape_node.output[0]], attr={"to": target_dtype},
-                                     shapes=[shape_node_output_shape], dtypes=[target_dtype],
-                                     op_name_scope=node.name)
-        nodes.append(shape_f_node)
-
-        split_attr = [1 for i in range(rank)]
-        output_shapes = [[1] for i in range(rank)]
-        output_dtypes = [target_dtype for i in range(rank)]
-        split_node = ctx.make_node("Split", [shape_f_node.output[0]], output_count=rank,
-                                   attr={"split": split_attr}, shapes=output_shapes,
-                                   dtypes=output_dtypes, op_name_scope=node.name)
-        nodes.append(split_node)
-
-        trip_cnts = []
-        for i in range(rank):
-            output_id = split_node.output[i]
-            output_shape = ctx.get_shape(output_id)
-            target_dtype = TensorProto.INT64
-            shape_i_node = ctx.make_node("Cast", [output_id], attr={"to": target_dtype},
-                                         shapes=[output_shape], dtypes=[target_dtype],
-                                         op_name_scope=node.name)
-            trip_cnts.append(shape_i_node.output[0])
-            nodes.append(shape_i_node)
-        # workaround ends
-
-        onnx_nodes = create_loop_op(ctx, node.input, true_data_type, true_data_shape, trip_cnts, rank)
-        nodes.extend(onnx_nodes)
-        loop_node = onnx_nodes[-1]
-
-        val_output_id = loop_node.output[1]
-    elif rank == 0:
-        if_node, val_output_id = create_if_op(ctx, node.input, true_data_type, true_data_shape)
-        nodes.append(if_node)
-
-    ctx.copy_shape(node.output[0], val_output_id)
-    ctx.set_dtype(node.output[0], true_data_type)
-
-    output_node = ctx.make_node("Identity", [val_output_id], outputs=node.output,
-                                shapes=[ctx.get_shape(val_output_id)], dtypes=[true_data_type])
-    nodes.append(output_node)
-
-    return nodes
-
-
-# gather_input_ids contains 3 elements:
-# 0: condition data to gather on
-# 1: true result to gather on
-# 2: false result to gather on
-def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_input_ids, rank):
-    nodes = []
-    cond_var_name = utils.make_name("cond_var")
-    nodes.append(g.make_const(cond_var_name, np.array(True, dtype=np.bool)))
-
-    # Loop requires at least one variable, so add a useless fake variable.
-    fake_val_name = utils.make_name("fake_var")
-    nodes.append(g.make_const(fake_val_name, np.array(0.0, dtype=np.float32)))
-
-    if rank < 1:
-        raise ValueError("rank is < 1")
-    trip_count_input_id = trip_count_input_ids[-1 * rank]
-
-    loop_inputs = [trip_count_input_id,  # trip count
-                   cond_var_name,  # termination condition
-                   fake_val_name  # initial value of loop-carried dependencies
-                   ]
-    # define an extra scan output
-    loop_node = g.make_node("Loop", loop_inputs, output_count=2, op_name_scope="select_loop",
-                            skip_conversion=False)
-    loop_body = create_loop_body_graph(g, gather_input_ids, output_type, output_shape, trip_count_input_ids,
-                                       rank, loop_node.name)
-    loop_node.set_body_graph_as_attr("body", loop_body)
-    nodes.append(loop_node)
-    return nodes
-
-
-def get_inputs_for_current_iteration(g, input_id, iter_index):
-    nodes = []
-    cond_gather_node = g.make_node("Gather", [input_id, iter_index])
-    nodes.append(cond_gather_node)
-
-    cur_cond_val_scalar_node = g.make_node("Squeeze", [cond_gather_node.output[0]], attr={"axes": [0]})
-    nodes.append(cur_cond_val_scalar_node)
-
-    return nodes, cur_cond_val_scalar_node.output[0]
-
-
-def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
-                           rank, loop_name):
-    g = parent_g.create_new_graph_with_same_config()
-    iter_name = utils.make_name("i")
-    cond_name = utils.make_name("cond")
-    fake_var_name = utils.make_name("fake_var")
-
-    g.add_graph_input(iter_name, TensorProto.INT64, (1,))  # iteration_num
-    g.add_graph_input(cond_name, TensorProto.BOOL, ())  # condition
-    g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency
-    nodes = g.get_nodes()
-    # get the i'th value of condition
-    cond_input_id = gather_input_ids[0]
-    new_nodes, cond_input_id_for_current_iter = get_inputs_for_current_iteration(g, cond_input_id, iter_name)
-    nodes.extend(new_nodes)
-
-    # get the i'th value of true values
-    true_input_id = gather_input_ids[1]
-    new_nodes, true_input_id_for_current_iter = get_inputs_for_current_iteration(g, true_input_id, iter_name)
-    nodes.extend(new_nodes)
-
-    # get the i'th value of false values
-    false_input_id = gather_input_ids[2]
-    new_nodes, false_input_id_for_current_iter = get_inputs_for_current_iteration(g, false_input_id, iter_name)
-    nodes.extend(new_nodes)
-
-    input_ids_for_current_iter = [cond_input_id_for_current_iter, true_input_id_for_current_iter,
-                                  false_input_id_for_current_iter]
-    output_id = None
-    rank = rank - 1
-    if rank >= 1:
-        nodes_1 = create_loop_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:],
-                                 trip_count_input_ids, rank)
-        loop_1 = nodes_1[-1]
-        output_id = loop_1.output[1]
-        nodes.extend(nodes_1)
-    elif rank == 0:
-        if_node, if_node_output_id = create_if_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:])
-        output_id = if_node_output_id
-        nodes.append(if_node)
-
-    output_identity_name = utils.make_name("loop_output")
-    loop_output_id = utils.port_name(output_identity_name)
-    loop_output_node = g.make_node(
-        'Identity',
-        [output_id],
-        outputs=[loop_output_id],
-        name=output_identity_name
-    )
-    nodes.append(loop_output_node)
-
-    cond_identity_name = utils.make_name("cond_output")
-    cond_output_id = utils.port_name(cond_identity_name)
-    identity_node = g.make_node(
-        'Identity',
-        [cond_name],
-        outputs=[cond_output_id],
-        name=cond_identity_name
-    )
-    nodes.append(identity_node)
-
-    fake_var_identity_name = utils.make_name("fake_var_output")
-    fake_var_output_id = utils.port_name(fake_var_identity_name)
-    identity_node = g.make_node(
-        'Identity',
-        [fake_var_name],
-        outputs=[fake_var_output_id],
-        name=fake_var_identity_name
-    )
-    nodes.append(identity_node)
-
-    g.set_nodes(nodes)
-
-    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
-    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())
-
-    # use None for all dims and keep only the original rank, since dims have been observed to change inside the loop
-    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))
-
-    return g
-
-
-def create_if_op(g, input_ids, output_data_type, output_shape):
-    op_name = utils.make_name("If")
-    true_graph = create_body_graph_for_if_branch(g, output_data_type, output_shape, input_ids[1], op_name)
-    false_graph = create_body_graph_for_if_branch(g, output_data_type, output_shape, input_ids[2], op_name)
-    out_name = port_name(op_name)
-
-    # output a scalar
-    if_node = g.make_node("If", [input_ids[0]], outputs=[out_name], name=op_name, skip_conversion=False)
-    if_node.set_body_graph_as_attr("then_branch", true_graph)
-    if_node.set_body_graph_as_attr("else_branch", false_graph)
-    return if_node, out_name
-
-
-def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
-    g = parent_g.create_new_graph_with_same_config()
-    nodes = []
-    name = utils.make_name("Identity")
-    identity_node = g.make_node(
-        'Identity',
-        inputs=[chosen_cur_cond_val_out_name],
-        outputs=['y'],
-        name=name
-    )
-    nodes.append(identity_node)
-    g.set_nodes(nodes)
-    g.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
-    return g
+    select_x_dtype = ctx.get_dtype(node.input[1])
+    select_x_shape = ctx.get_shape(node.input[1])
+    cond = node.inputs[0]
+    cond_shape = ctx.get_shape(cond.output[0])
+    utils.make_sure(select_x_shape is not None and cond_shape is not None, "ranks of the inputs are needed")
+
+    added_nodes = []
+    true_mask = ctx.make_node("Cast", cond.output, attr={"to": select_x_dtype})
+    cond_not = ctx.make_node("Not", cond.output)
+    false_mask = ctx.make_node("Cast", cond_not.output, attr={"to": select_x_dtype})
+    added_nodes.extend([true_mask, cond_not, false_mask])
+    # the broadcasting rule of Select differs from the common rule:
+    # for example, if cond has shape (10) and x has shape (10, 2), common broadcasting fails,
+    # while tf "select" treats cond as (10, 1) and repeats it along the last dimension,
+    # so Unsqueeze nodes are inserted here to append the missing trailing dimensions
+    unsqueeze_dim_num = len(select_x_shape) - len(cond_shape)
+    utils.make_sure(unsqueeze_dim_num >= 0, "rank of select_x must not be less than rank of cond")
+    if unsqueeze_dim_num != 0:
+        # the new singleton dims occupy the trailing axes of the unsqueezed output
+        axes = list(range(len(cond_shape), len(select_x_shape)))
+        true_mask = ctx.make_node("Unsqueeze", true_mask.output, attr={"axes": axes})
+        false_mask = ctx.make_node("Unsqueeze", false_mask.output, attr={"axes": axes})
+        added_nodes.extend([true_mask, false_mask])
+
+    select_from_true = ctx.make_node("Mul", [true_mask.output[0], node.input[1]])
+    select_from_false = ctx.make_node("Mul", [false_mask.output[0], node.input[2]])
+    res = ctx.make_node("Add", [select_from_true.output[0], select_from_false.output[0]],
+                        name=node.name, outputs=node.output,
+                        shapes=[ctx.get_shape(node.output[0])], dtypes=[ctx.get_dtype(node.output[0])])
+    return [*added_nodes, select_from_true, select_from_false, res]
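The rewritten converter above replaces the old Loop/If subgraphs with a flat mask-multiply-add decomposition: cast the boolean condition to the data dtype to get a 0/1 mask, multiply it with the true values, multiply the inverted mask with the false values, and add the two products. A minimal NumPy sketch of the same idea (illustrative only, not part of the commit), including the trailing-dimension broadcast that the Unsqueeze nodes emulate:

import numpy as np

def select_via_masks(cond, x, y):
    # 0/1 masks in the data dtype, mirroring the Cast and Not+Cast nodes
    true_mask = cond.astype(x.dtype)
    false_mask = (~cond).astype(x.dtype)
    # tf Select lets a lower-rank cond pick along the leading axes of x/y;
    # appending trailing singleton dims mirrors the inserted Unsqueeze nodes
    extra_dims = x.ndim - cond.ndim
    if extra_dims > 0:
        new_shape = cond.shape + (1,) * extra_dims
        true_mask = true_mask.reshape(new_shape)
        false_mask = false_mask.reshape(new_shape)
    return true_mask * x + false_mask * y  # the Mul, Mul and Add nodes

cond = np.array([True, False, True])
x = np.full((3, 2), 1.0)
y = np.full((3, 2), 9.0)
print(select_via_masks(cond, x, y))  # rows: [1. 1.], [9. 9.], [1. 1.]

One trade-off of this decomposition: both branches are always evaluated and multiplied, so non-finite values leak through the zero mask (0 * NaN is NaN), which a true element-wise where would suppress.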

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
@@ -1834,11 +1834,11 @@ def where_op(ctx, node, name, args):
     "If": (direct_op, []),
     "Loop": (direct_op, []),
     "Scan": (direct_op, []),
+    "Select": (select_op7, []),
 }
 
 _OPSET_8 = {
     "ReverseSequence": (reverse_op8, []),  # make use of scan
-    "Select": (select_op8, []),
 }
 
 _OPSET_9 = {
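For context, the _OPSET_N tables map TF op names to conversion handlers, so moving "Select" from _OPSET_8 into the opset-7 table makes the new converter available one opset earlier. A rough sketch of how such per-opset tables can compose (illustrative; the function name and mechanism are assumptions, not tf2onnx's actual code):

def build_handler_table(target_opset, tables_by_opset):
    # layer the per-opset tables from oldest to newest, up to the target opset;
    # a later table overrides an earlier entry for the same op name
    handlers = {}
    for opset in sorted(tables_by_opset):
        if opset <= target_opset:
            handlers.update(tables_by_opset[opset])
    return handlers

With {7: {"Select": select_op7}, 8: {"ReverseSequence": reverse_op8}}, a target opset of 7 already resolves "Select", while a target of 8 adds the Scan-based ReverseSequence handler as well.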
