|
4 | 4 | """
|
5 | 5 | tf2onnx.tf2onnx - select op conversion
|
6 | 6 | """
|
7 |
| - |
| 7 | +import numpy as np |
| 8 | +from onnx.onnx_pb import TensorProto |
8 | 9 | from tf2onnx import utils
|
| 10 | +from tf2onnx.utils import port_name, make_sure |
9 | 11 |
|
10 | 12 |
|
11 | 13 | # pylint: disable=unused-argument,missing-docstring
|
12 | 14 |
|
13 | 15 |
|
def select_op8(ctx, node, name, args):
    """Convert a TF ``Select`` node into ONNX ``Loop``/``If`` nodes.

    T output = Select(bool condition, T x, T y)

    When the condition has rank > 0, the shape of the condition is read at
    runtime (Shape -> Cast -> Split -> Cast) to obtain one trip count per
    dimension, and ``create_loop_op`` builds the loop that walks the
    condition. When the condition is a scalar, a single ``If`` node built by
    ``create_if_op`` picks between the two data inputs.

    Args:
        ctx: graph being converted; new nodes are created through it.
        node: the Select node; inputs are (condition, x, y).
        name: unused here.
        args: unused here.

    Returns:
        List of the nodes created for the conversion. The last one is an
        ``Identity`` that writes to the original outputs of ``node``.
    """
    # x (input[1]) and y (input[2]) are both indexed further down, so all three
    # inputs must be present; fail here instead of with an IndexError after
    # part of the graph has already been built.
    utils.make_sure(len(node.input) > 2, "Select with only condition is not supported.")

    nodes = []
    true_data_type = ctx.get_dtype(node.input[1])
    true_data_shape = ctx.get_shape(node.input[1])
    make_sure(true_data_type is not None, "select true data dtype cannot be None")
    make_sure(true_data_shape is not None, "select true data shape cannot be None")

    condition_shape = ctx.get_shape(node.input[0])
    utils.make_sure(condition_shape is not None, "condition shape is None")
    # len() is never negative, so only the "rank > 0" and "rank == 0" cases exist.
    rank = len(condition_shape)

    if rank > 0:
        # create nodes getting shape of condition
        shape_node_output_shape = [rank]
        shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name,
                                   shapes=[shape_node_output_shape], dtypes=[TensorProto.INT64])
        nodes.append(shape_node)

        # todo(pengwa), move those leveraging rewrite_incomplete_type_support_onnxruntime after shape inferencing
        # bug is fixed.
        # workaround: onnxruntime does not support Split-2, add casts before and after.
        # NOTE(review): the round trip goes through FLOAT; presumably fine for
        # realistic dimension sizes, confirm if very large dims are expected.
        float_dtype = TensorProto.FLOAT
        shape_f_node = ctx.make_node("Cast", [shape_node.output[0]], attr={"to": float_dtype},
                                     shapes=[shape_node_output_shape], dtypes=[float_dtype],
                                     op_name_scope=node.name)
        nodes.append(shape_f_node)

        # split the rank-sized shape vector into `rank` one-element pieces
        split_node = ctx.make_node("Split", [shape_f_node.output[0]], output_count=rank,
                                   attr={"split": [1] * rank},
                                   shapes=[[1] for _ in range(rank)],
                                   dtypes=[float_dtype] * rank, op_name_scope=node.name)
        nodes.append(split_node)

        int_dtype = TensorProto.INT64
        trip_cnts = []
        for i in range(rank):
            output_id = split_node.output[i]
            shape_i_node = ctx.make_node("Cast", [output_id], attr={"to": int_dtype},
                                         shapes=[ctx.get_shape(output_id)], dtypes=[int_dtype],
                                         op_name_scope=node.name)
            trip_cnts.append(shape_i_node.output[0])
            nodes.append(shape_i_node)
        # workaround ends

        onnx_nodes = create_loop_op(ctx, node.input, true_data_type, true_data_shape, trip_cnts, rank)
        nodes.extend(onnx_nodes)
        # create_loop_op returns the Loop node last; output[1] is the second of
        # the two outputs it was created with.
        val_output_id = onnx_nodes[-1].output[1]
    else:
        # scalar condition: a single If chooses between the two data inputs
        if_node, val_output_id = create_if_op(ctx, node.input, true_data_type, true_data_shape)
        nodes.append(if_node)

    ctx.copy_shape(node.output[0], val_output_id)
    ctx.set_dtype(node.output[0], true_data_type)

    output_node = ctx.make_node("Identity", [val_output_id], outputs=node.output,
                                shapes=[ctx.get_shape(val_output_id)], dtypes=[true_data_type])
    nodes.append(output_node)

    return nodes
| 86 | + |
| 87 | + |
def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_input_ids, rank):
    """Create one ``Loop`` node (plus its constant inputs) for Select.

    Args:
        g: graph the nodes are created in.
        gather_input_ids: three-element sequence of ids:
            0: condition data to gather on
            1: true result to gather on
            2: false result to gather on
        output_type: dtype passed through to the loop body graph.
        output_shape: shape passed through to the loop body graph.
        trip_count_input_ids: ids holding the trip counts; the entry ``rank``
            positions from the end is used for this loop.
        rank: number of dimensions still to iterate over; must be >= 1.

    Returns:
        List of created nodes: the two constants first, the Loop node last.

    Raises:
        ValueError: if ``rank`` is smaller than 1.
    """
    # validate before touching the graph so a bad rank leaves no stray constants
    if rank < 1:
        raise ValueError("rank is < 1")

    nodes = []
    cond_var_name = utils.make_name("cond_var")
    # np.bool_ is available in every NumPy release; the plain np.bool alias
    # was removed in NumPy 1.24.
    nodes.append(g.make_const(cond_var_name, np.array(True, dtype=np.bool_)))

    # Loop requires at least a variable, add a useless fake variable.
    fake_val_name = utils.make_name("fake_var")
    nodes.append(g.make_const(fake_val_name, np.array(0.0, dtype=np.float32)))

    trip_count_input_id = trip_count_input_ids[-rank]

    loop_inputs = [trip_count_input_id,  # trip count
                   cond_var_name,  # termination condition
                   fake_val_name  # initial value of loop-carried dependencies
                   ]
    # define an extra scan output
    loop_node = g.make_node("Loop", loop_inputs, output_count=2, op_name_scope="select_loop",
                            skip_conversion=False)
    loop_body = create_loop_body_graph(g, gather_input_ids, output_type, output_shape, trip_count_input_ids,
                                       rank, loop_node.name)
    loop_node.set_body_graph_as_attr("body", loop_body)
    nodes.append(loop_node)
    return nodes
| 117 | + |
| 118 | + |
def get_inputs_for_current_iteration(g, input_id, iter_index):
    """Pick the slice of ``input_id`` belonging to the current loop iteration.

    A ``Gather`` on ``input_id`` with ``iter_index`` is followed by a
    ``Squeeze`` over axis 0.

    Returns:
        Tuple ``(nodes, output_id)``: the two created nodes (Gather, then
        Squeeze) and the id of the Squeeze output.
    """
    gathered = g.make_node("Gather", [input_id, iter_index])
    squeezed = g.make_node("Squeeze", [gathered.output[0]], attr={"axes": [0]})
    return [gathered, squeezed], squeezed.output[0]
| 128 | + |
| 129 | + |
def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
                           rank, loop_name):
    """Build the body graph for one level of the Select loop nest.

    The body takes the iteration number, the loop condition and the fake
    loop-carried variable as graph inputs. It slices condition / true / false
    data for the current iteration, then either nests another loop (more than
    one dimension left) or emits an ``If`` (last dimension). Condition and
    fake variable are forwarded unchanged through ``Identity`` nodes.

    Args:
        parent_g: graph whose configuration the new body graph copies.
        gather_input_ids: ids of condition, true data and false data (in that order).
        output_data_type: dtype of the selected values.
        output_shape: shape of the data at this level; the leading dim is dropped for the body output.
        trip_count_input_ids: trip-count ids, forwarded to nested loops.
        rank: number of dimensions left including the one this body iterates.
        loop_name: name of the owning Loop node (not used in the body).

    Returns:
        The newly created body graph.
    """
    body = parent_g.create_new_graph_with_same_config()
    iter_id = utils.make_name("i")
    cond_in_id = utils.make_name("cond")
    carried_in_id = utils.make_name("fake_var")

    body.add_graph_input(iter_id, TensorProto.INT64, (1,))  # iteration_num
    body.add_graph_input(cond_in_id, TensorProto.BOOL, ())  # condition
    body.add_graph_input(carried_in_id, TensorProto.FLOAT, ())  # loop-carried dependency
    body_nodes = body.get_nodes()

    # i'th value of the condition, the true values and the false values, in that order
    sliced_ids = []
    for position in range(3):
        slice_nodes, slice_id = get_inputs_for_current_iteration(body, gather_input_ids[position], iter_id)
        body_nodes.extend(slice_nodes)
        sliced_ids.append(slice_id)

    selected_id = None
    remaining_rank = rank - 1
    if remaining_rank >= 1:
        # more dimensions left: nest another loop; its Loop node comes last in the returned list
        inner_nodes = create_loop_op(body, sliced_ids, output_data_type, output_shape[1:],
                                     trip_count_input_ids, remaining_rank)
        selected_id = inner_nodes[-1].output[1]
        body_nodes.extend(inner_nodes)
    elif remaining_rank == 0:
        # innermost dimension: choose between the true and false value
        inner_if, selected_id = create_if_op(body, sliced_ids, output_data_type, output_shape[1:])
        body_nodes.append(inner_if)

    def forward(prefix, source_id):
        # add an Identity that republishes source_id under a freshly generated output id
        identity_name = utils.make_name(prefix)
        published_id = utils.port_name(identity_name)
        body_nodes.append(body.make_node('Identity', [source_id], outputs=[published_id],
                                         name=identity_name))
        return published_id

    loop_output_id = forward("loop_output", selected_id)
    cond_output_id = forward("cond_output", cond_in_id)
    fake_var_output_id = forward("fake_var_output", carried_in_id)

    body.set_nodes(body_nodes)

    body.add_graph_output(cond_output_id, TensorProto.BOOL, ())
    body.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())

    # use None for all dims, just keep original rank. Because it is observed, dims might be changed in loop.
    body.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))

    return body
| 211 | + |
| 212 | + |
def create_if_op(g, input_ids, output_data_type, output_shape):
    """Create an ``If`` node that forwards one of two inputs.

    Args:
        g: graph the node is created in.
        input_ids: ids of (condition, value for then-branch, value for else-branch).
        output_data_type: dtype declared for both branch outputs.
        output_shape: shape handed to the branch graph builder.

    Returns:
        Tuple ``(if_node, output_id)``.
    """
    if_name = utils.make_name("If")
    # build both branch graphs (then first, else second) before the If node itself
    branch_graphs = [
        (attr_name, create_body_graph_for_if_branch(g, output_data_type, output_shape,
                                                    input_ids[position], if_name))
        for attr_name, position in (("then_branch", 1), ("else_branch", 2))
    ]
    result_id = port_name(if_name)

    # output a scalar
    if_node = g.make_node("If", [input_ids[0]], outputs=[result_id], name=if_name, skip_conversion=False)
    for attr_name, branch_graph in branch_graphs:
        if_node.set_body_graph_as_attr(attr_name, branch_graph)
    return if_node, result_id
| 224 | + |
| 225 | + |
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
    """Build a branch graph for ``If`` consisting of a single ``Identity``.

    The Identity reads ``chosen_cur_cond_val_out_name`` and writes the graph
    output ``y``, declared with ``data_type`` and a vague shape derived from
    ``output_shape``. ``op_name`` is accepted but not used here.

    Returns:
        The new branch graph.
    """
    branch = parent_g.create_new_graph_with_same_config()
    passthrough = branch.make_node(
        'Identity',
        inputs=[chosen_cur_cond_val_out_name],
        outputs=['y'],
        name=utils.make_name("Identity")
    )
    branch.set_nodes([passthrough])
    branch.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
    return branch
0 commit comments