 """
 tf2onnx.tf2onnx - select op conversion
 """
-import numpy as np
-from onnx.onnx_pb import TensorProto
+
 from tf2onnx import utils
-from tf2onnx.utils import port_name, make_sure


 # pylint: disable=unused-argument,missing-docstring


-def select_op8(ctx, node, name, args):
+def select_op7(ctx, node, name, args):
     # T output = Select(bool condition, T x, T y)
-    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
-    utils.make_sure(len(node.input) > 1, "Select with only condition is not supported.")
-
-    nodes = []
-    true_data_type = ctx.get_dtype(node.input[1])
-    true_data_shape = ctx.get_shape(node.input[1])
-    make_sure(true_data_type is not None, "select true data dtype cannot be None")
-    make_sure(true_data_shape is not None, "select true data shape cannot be None")
-
-    condition_shape = ctx.get_shape(node.input[0])
-    utils.make_sure(condition_shape is not None, "condition shape is None")
-    rank = len(condition_shape)
-
-    utils.make_sure(rank >= 0, "rank should be >= 0")
-    val_output_id = None
-    if rank > 0:
-        # create nodes getting shape of condition
-        shape_node_output_shape = [rank]
-        shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name,
-                                   shapes=[shape_node_output_shape], dtypes=[TensorProto.INT64])
-        nodes.append(shape_node)
-
-        # TODO(pengwa): move this into rewrite_incomplete_type_support_onnxruntime
-        # once the shape-inferencing bug is fixed.
-        # workaround: onnxruntime does not support Split-2, so add Casts before and after.
-        target_dtype = TensorProto.FLOAT
-        shape_f_node = ctx.make_node("Cast", [shape_node.output[0]], attr={"to": target_dtype},
-                                     shapes=[shape_node_output_shape], dtypes=[target_dtype],
-                                     op_name_scope=node.name)
-        nodes.append(shape_f_node)
-
-        split_attr = [1 for i in range(rank)]
-        output_shapes = [[1] for i in range(rank)]
-        output_dtypes = [target_dtype for i in range(rank)]
-        split_node = ctx.make_node("Split", [shape_f_node.output[0]], output_count=rank,
-                                   attr={"split": split_attr}, shapes=output_shapes,
-                                   dtypes=output_dtypes, op_name_scope=node.name)
-        nodes.append(split_node)
-
-        trip_cnts = []
-        for i in range(rank):
-            output_id = split_node.output[i]
-            output_shape = ctx.get_shape(output_id)
-            target_dtype = TensorProto.INT64
-            shape_i_node = ctx.make_node("Cast", [output_id], attr={"to": target_dtype},
-                                         shapes=[output_shape], dtypes=[target_dtype],
-                                         op_name_scope=node.name)
-            trip_cnts.append(shape_i_node.output[0])
-            nodes.append(shape_i_node)
-        # workaround ends
-
-        onnx_nodes = create_loop_op(ctx, node.input, true_data_type, true_data_shape, trip_cnts, rank)
-        nodes.extend(onnx_nodes)
-        loop_node = onnx_nodes[-1]
-
-        val_output_id = loop_node.output[1]
-    elif rank == 0:
-        if_node, val_output_id = create_if_op(ctx, node.input, true_data_type, true_data_shape)
-        nodes.append(if_node)
-
-    ctx.copy_shape(node.output[0], val_output_id)
-    ctx.set_dtype(node.output[0], true_data_type)
-
-    output_node = ctx.make_node("Identity", [val_output_id], outputs=node.output,
-                                shapes=[ctx.get_shape(val_output_id)], dtypes=[true_data_type])
-    nodes.append(output_node)
-
-    return nodes
-
-
-# gather_input_ids is a list of 3 tensor ids:
-# 0: condition data to gather on
-# 1: true result to gather on
-# 2: false result to gather on
-def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_input_ids, rank):
-    nodes = []
-    cond_var_name = utils.make_name("cond_var")
-    nodes.append(g.make_const(cond_var_name, np.array(True, dtype=np.bool)))
-
-    # Loop requires at least one loop-carried variable, so add a useless fake one.
-    fake_val_name = utils.make_name("fake_var")
-    nodes.append(g.make_const(fake_val_name, np.array(0.0, dtype=np.float32)))
-
-    if rank < 1:
-        raise ValueError("rank is < 1")
-    trip_count_input_id = trip_count_input_ids[-1 * rank]
-
-    loop_inputs = [trip_count_input_id,  # trip count
-                   cond_var_name,  # termination condition
-                   fake_val_name  # initial value of loop-carried dependencies
-                   ]
-    # define an extra scan output
-    loop_node = g.make_node("Loop", loop_inputs, output_count=2, op_name_scope="select_loop",
-                            skip_conversion=False)
-    loop_body = create_loop_body_graph(g, gather_input_ids, output_type, output_shape, trip_count_input_ids,
-                                       rank, loop_node.name)
-    loop_node.set_body_graph_as_attr("body", loop_body)
-    nodes.append(loop_node)
-    return nodes
-
-
-def get_inputs_for_current_iteration(g, input_id, iter_index):
-    nodes = []
-    cond_gather_node = g.make_node("Gather", [input_id, iter_index])
-    nodes.append(cond_gather_node)
-
-    cur_cond_val_scalar_node = g.make_node("Squeeze", [cond_gather_node.output[0]], attr={"axes": [0]})
-    nodes.append(cur_cond_val_scalar_node)
-
-    return nodes, cur_cond_val_scalar_node.output[0]
-
-
-def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
-                           rank, loop_name):
-    g = parent_g.create_new_graph_with_same_config()
-    iter_name = utils.make_name("i")
-    cond_name = utils.make_name("cond")
-    fake_var_name = utils.make_name("fake_var")
-
-    g.add_graph_input(iter_name, TensorProto.INT64, (1,))  # iteration_num
-    g.add_graph_input(cond_name, TensorProto.BOOL, ())  # condition
-    g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency
-    nodes = g.get_nodes()
-    # get the i'th value of condition
-    cond_input_id = gather_input_ids[0]
-    new_nodes, cond_input_id_for_current_iter = get_inputs_for_current_iteration(g, cond_input_id, iter_name)
-    nodes.extend(new_nodes)
-
-    # get the i'th value of true values
-    true_input_id = gather_input_ids[1]
-    new_nodes, true_input_id_for_current_iter = get_inputs_for_current_iteration(g, true_input_id, iter_name)
-    nodes.extend(new_nodes)
-
-    # get the i'th value of false values
-    false_input_id = gather_input_ids[2]
-    new_nodes, false_input_id_for_current_iter = get_inputs_for_current_iteration(g, false_input_id, iter_name)
-    nodes.extend(new_nodes)
-
-    input_ids_for_current_iter = [cond_input_id_for_current_iter, true_input_id_for_current_iter,
-                                  false_input_id_for_current_iter]
-    output_id = None
-    rank = rank - 1
-    if rank >= 1:
-        nodes_1 = create_loop_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:],
-                                 trip_count_input_ids, rank)
-        loop_1 = nodes_1[-1]
-        output_id = loop_1.output[1]
-        nodes.extend(nodes_1)
-    elif rank == 0:
-        if_node, if_node_output_id = create_if_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:])
-        output_id = if_node_output_id
-        nodes.append(if_node)
-
-    output_identity_name = utils.make_name("loop_output")
-    loop_output_id = utils.port_name(output_identity_name)
-    loop_output_node = g.make_node(
-        'Identity',
-        [output_id],
-        outputs=[loop_output_id],
-        name=output_identity_name
-    )
-    nodes.append(loop_output_node)
-
-    cond_identity_name = utils.make_name("cond_output")
-    cond_output_id = utils.port_name(cond_identity_name)
-    identity_node = g.make_node(
-        'Identity',
-        [cond_name],
-        outputs=[cond_output_id],
-        name=cond_identity_name
-    )
-    nodes.append(identity_node)
-
-    fake_var_identity_name = utils.make_name("fake_var_output")
-    fake_var_output_id = utils.port_name(fake_var_identity_name)
-    identity_node = g.make_node(
-        'Identity',
-        [fake_var_name],
-        outputs=[fake_var_output_id],
-        name=fake_var_identity_name
-    )
-    nodes.append(identity_node)
-
-    g.set_nodes(nodes)
-
-    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
-    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())
-
-    # use None for all dims and keep only the original rank, since dims have been
-    # observed to change inside the loop.
-    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))
-
-    return g
-
-
-def create_if_op(g, input_ids, output_data_type, output_shape):
-    op_name = utils.make_name("If")
-    true_graph = create_body_graph_for_if_branch(g, output_data_type, output_shape, input_ids[1], op_name)
-    false_graph = create_body_graph_for_if_branch(g, output_data_type, output_shape, input_ids[2], op_name)
-    out_name = port_name(op_name)
-
-    # output a scalar
-    if_node = g.make_node("If", [input_ids[0]], outputs=[out_name], name=op_name, skip_conversion=False)
-    if_node.set_body_graph_as_attr("then_branch", true_graph)
-    if_node.set_body_graph_as_attr("else_branch", false_graph)
-    return if_node, out_name
-
-
-def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
-    g = parent_g.create_new_graph_with_same_config()
-    nodes = []
-    name = utils.make_name("Identity")
-    identity_node = g.make_node(
-        'Identity',
-        inputs=[chosen_cur_cond_val_out_name],
-        outputs=['y'],
-        name=name
-    )
-    nodes.append(identity_node)
-    g.set_nodes(nodes)
-    g.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
-    return g
+    select_x_dtype = ctx.get_dtype(node.input[1])
+    select_x_shape = ctx.get_shape(node.input[1])
+    cond = node.inputs[0]
+    cond_shape = ctx.get_shape(cond.output[0])
+    utils.make_sure(select_x_shape is not None and cond_shape is not None,
+                    "shapes of cond and x must be known")
+
+    added_nodes = []
+    true_mask = ctx.make_node("Cast", cond.output, attr={"to": select_x_dtype})
+    cond_not = ctx.make_node("Not", cond.output)
+    false_mask = ctx.make_node("Cast", cond_not.output, attr={"to": select_x_dtype})
+    added_nodes.extend([true_mask, cond_not, false_mask])
+    # the broadcasting rule of tf select differs from the common (numpy-style) rule:
+    # e.g. if cond has shape (10) and x has shape (10, 2), common broadcasting fails,
+    # while tf select treats cond as (10, 1) and repeats it along the last dimension,
+    # so Unsqueeze nodes are inserted here.
+    unsqueeze_dim_num = len(select_x_shape) - len(cond_shape)
+    utils.make_sure(unsqueeze_dim_num >= 0, "rank of x must not be less than rank of cond")
+    if unsqueeze_dim_num != 0:
+        # append one trailing singleton axis per missing dimension,
+        # i.e. axes len(cond_shape) .. len(select_x_shape) - 1
+        axes = list(range(len(cond_shape), len(select_x_shape)))
+        true_mask = ctx.make_node("Unsqueeze", true_mask.output, attr={"axes": axes})
+        false_mask = ctx.make_node("Unsqueeze", false_mask.output, attr={"axes": axes})
+        added_nodes.extend([true_mask, false_mask])
+
+    select_from_true = ctx.make_node("Mul", [true_mask.output[0], node.input[1]])
+    select_from_false = ctx.make_node("Mul", [false_mask.output[0], node.input[2]])
+    res = ctx.make_node("Add", [select_from_true.output[0], select_from_false.output[0]],
+                        name=node.name, outputs=node.output,
+                        shapes=[ctx.get_shape(node.output[0])], dtypes=[ctx.get_dtype(node.output[0])])
+    return [*added_nodes, select_from_true, select_from_false, res]
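
The new conversion expresses Select with ops already available in opset 7 (Cast, Not, Mul, Add) rather than the Loop/If subgraphs the opset-8 version built. The identity it relies on is where(cond, x, y) == cast(cond) * x + cast(not cond) * y. A minimal numpy sketch of that identity (numpy stands in for the ONNX runtime here; select_by_mask is an illustrative name, not part of the diff):

import numpy as np

def select_by_mask(cond, x, y):
    # cast(cond) * x + cast(not cond) * y -- mirrors the Cast/Not/Mul/Add nodes
    true_mask = cond.astype(x.dtype)
    false_mask = (~cond).astype(x.dtype)
    return true_mask * x + false_mask * y

cond = np.array([True, False, True])
x = np.array([1.0, 2.0, 3.0])
y = np.array([10.0, 20.0, 30.0])
assert np.array_equal(select_by_mask(cond, x, y), np.where(cond, x, y))

One caveat of this formulation: because the masked-out branch still participates in the Mul/Add, a NaN or inf in either branch propagates into the result, unlike a true element-wise select.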
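The Unsqueeze branch covers the one place where tf select's broadcasting diverges from the standard rule: a lower-rank cond is matched against the leading dimensions of x, so the mask needs trailing singleton axes before standard broadcasting applies. A small sketch of the axes computation and its effect, again with numpy as a stand-in (all names illustrative):

import numpy as np

cond = np.array([True, False])     # shape (2,): one flag per row of x
x = np.arange(6.0).reshape(2, 3)   # shape (2, 3)
y = np.zeros((2, 3))

# one trailing axis per missing dimension: positions len(cond.shape) .. len(x.shape) - 1
axes = tuple(range(cond.ndim, x.ndim))             # (1,) here
mask = np.expand_dims(cond.astype(x.dtype), axes)  # shape (2, 1)

result = mask * x + (1.0 - mask) * y               # row 0 from x, row 1 from y
assert np.array_equal(result, np.where(cond[:, None], x, y))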