Skip to content

Commit ee2b202

Browse files
authored
Merge pull request #1056 from xadupre/input
Use the same syntax to replace an node input
2 parents d3d301a + e6d4bcb commit ee2b202

File tree

14 files changed

+117
-88
lines changed

14 files changed

+117
-88
lines changed

tf2onnx/custom_opsets/ms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def version_1(cls, ctx, node, **kwargs):
8585
# set node's attrs, Note: output_padding, group are left default.
8686
conv_dims_attr(node, "dilations")
8787
# set node's inputs from (output_shape, filter, input_tensor) to (input_tensor, filter, pads, Bias)
88-
node.input[0] = node.input[2]
89-
node.input[2] = pads.output[0]
88+
ctx.replace_input(node, node.input[0], node.input[2], 0)
89+
ctx.replace_input(node, node.input[2], pads.output[0], 2)
9090
conv_convert_inputs(ctx, node, with_kernel=True)
9191
node.attr.pop("data_format")
9292
node.attr.pop("padding")

tf2onnx/graph.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def input(self):
5656

5757
@input.setter
5858
def input(self, val):
59+
# The setter can catch that all inputs are change
60+
# but it cannot catch that one input is changed.
61+
# That's method replace_input and replace_inputs must
62+
# be used to change inputs to let the graph instance
63+
# update its internal indices.
5964
self._input = copy.deepcopy(val)
6065

6166
@property
@@ -1134,13 +1139,21 @@ def dump_node_statistics(self):
11341139
return op_cnt
11351140

11361141
@staticmethod
1137-
def remove_input(node, to_be_removed):
1142+
def remove_input(node, to_be_removed, input_index=None):
11381143
"""Remove input from Node.
11391144
Args:
11401145
node: the node we expect the input on
11411146
to_be_removed: the node name we want to remove
1147+
input_index: if not None, index of the input to be removed,
1148+
the method is more efficient if *input_index* is specified,
1149+
otherwise, it has to look for every input named *old_input*.
11421150
"""
11431151
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
1152+
if input_index is not None:
1153+
assert node.input[input_index] == to_be_removed
1154+
del node.input[input_index]
1155+
return True
1156+
11441157
for i, name in enumerate(node.input):
11451158
if name == to_be_removed:
11461159
del node.input[i]
@@ -1171,7 +1184,7 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
11711184
new_node = self.make_node(op_type, input_name, attr=kwargs, outputs=[new_output], name=name, domain=domain)
11721185
for i, n in enumerate(node.input):
11731186
if n == input_name[0]:
1174-
node.input[i] = new_output
1187+
self.replace_input(node, node.input[i], new_output, i)
11751188
break
11761189
return new_node
11771190

@@ -1232,17 +1245,32 @@ def replace_all_inputs(ops, old_input, new_input):
12321245
for g in body_graphs.values():
12331246
g.replace_all_inputs(g.get_nodes(), old_input, new_input)
12341247

1235-
@staticmethod
1236-
def replace_input(node, old_input, new_input):
1237-
"""Replace node."""
1248+
def replace_input(self, node, old_input, new_input, input_index=None):
1249+
"""
1250+
Replace one input in a node.
1251+
The method is more efficient if *input_index* is specified.
1252+
Otherwise, it renames every output named *old_input*.
1253+
"""
12381254
assert isinstance(node, Node) and isinstance(old_input, six.text_type) and isinstance(new_input, six.text_type)
12391255
is_replaced = False
1240-
for i, input_name in enumerate(node.input):
1241-
if input_name == old_input:
1242-
node.input[i] = new_input
1243-
is_replaced = True
1256+
if input_index is None:
1257+
for i, input_name in enumerate(node.input):
1258+
if input_name == old_input:
1259+
node.input[i] = new_input
1260+
is_replaced = True
1261+
elif node.input[input_index] == old_input:
1262+
node.input[input_index] = new_input
1263+
is_replaced = True
1264+
else:
1265+
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))
12441266
return is_replaced
12451267

1268+
def replace_inputs(self, node, new_inputs):
1269+
"""Replace node inputs."""
1270+
assert isinstance(node, Node) and isinstance(new_inputs, list)
1271+
node.input = new_inputs
1272+
return True
1273+
12461274
def _extract_sub_graph_nodes(self, dest_node, input_checker=None):
12471275
"""Return nodes of subgraph ending with dest_node.
12481276
Args:

tf2onnx/onnx_opset/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def version_1(cls, ctx, node, **kwargs):
4040
shape1 = node.inputs[1].scalar_to_dim1()
4141
if shape0 and shape1 and len(shape0) < len(shape1) and node.type in ["Mul", "Add"]:
4242
tmp = node.input[0]
43-
node.input[0] = node.input[1]
44-
node.input[1] = tmp
43+
ctx.replace_input(node, node.input[0], node.input[1], 0)
44+
ctx.replace_input(node, node.input[1], tmp, 1)
4545
else:
4646
node.set_attr("broadcast", 0)
4747

@@ -65,5 +65,5 @@ def version_6(cls, ctx, node, **kwargs):
6565
shape1 = node.inputs[1].scalar_to_dim1()
6666
if shape0 and shape1 and len(shape0) < len(shape1) and node.type in ["Mul", "Add"]:
6767
tmp = node.input[0]
68-
node.input[0] = node.input[1]
69-
node.input[1] = tmp
68+
ctx.replace_input(node, node.input[0], node.input[1], 0)
69+
ctx.replace_input(node, node.input[1], tmp, 1)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def version_9(cls, ctx, node, **kwargs):
374374
broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
375375
shape_const = ctx.make_const(utils.make_name(node.name), np.array(broadcast_shape, dtype=np.int64))
376376
reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
377-
ctx.replace_input(node, node.input[0], reshape.output[0])
377+
ctx.replace_input(node, node.input[0], reshape.output[0], 0)
378378

379379

380380
@tf_op("Where")
@@ -456,7 +456,7 @@ class TensorListGetItem:
456456
def version_7(cls, ctx, node, **kwargs):
457457
ctx.ta_reads.append(node.input[0])
458458
node.type = "Gather"
459-
node.input = [node.input[0], node.input[1]]
459+
ctx.replace_inputs(node, [node.input[0], node.input[1]])
460460
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], name=node.child_name(), axes=[0])
461461
ctx.insert_new_node_on_output("Squeeze", node.output[0], name=node.child_name(), axes=[0])
462462

@@ -531,7 +531,7 @@ def version_7(cls, ctx, node, **kwargs):
531531
else:
532532
maximum_iterations_name = utils.make_name(node.inputs[1].name)
533533
ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
534-
node.input[1] = maximum_iterations_name
534+
ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
535535

536536
cond_name = node.get_attr_str("cond")
537537
cond_graph = find_function(cond_name)
@@ -642,7 +642,7 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
642642
for n in g.inputs:
643643
if n.output[0] in body_input_to_state_var:
644644
n.type = "Identity"
645-
n.input = [body_input_to_state_var[n.output[0]]]
645+
g.replace_inputs(n, [body_input_to_state_var[n.output[0]]])
646646

647647
# onnx will pass in cond as argument
648648
cond_node = g.make_node("Placeholder", [], name=utils.make_name("cond"),
@@ -673,7 +673,7 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
673673
node.type = "Identity"
674674
g.set_shape(node.output[0], g.get_shape(node.input[2]))
675675
g.set_dtype(node.output[0], g.get_dtype(node.input[2]))
676-
node.input = [node.input[2]]
676+
g.replace_inputs(node, [node.input[2]])
677677
scan_outputs.append(node.output[0])
678678

679679
if len(scan_outputs) != len(removed_scan_outputs):
@@ -729,7 +729,7 @@ def wire_if_branch(parent_g, g, inputs, output_shapes, output_dtypes, scope, par
729729
for node in g.inputs:
730730
parent_name = binding.get(node.output[0])
731731
if parent_name and parent_name != "@@ALLOC":
732-
node.input = [parent_name]
732+
g.replace_inputs(node, [parent_name])
733733
node.type = "Identity"
734734
else:
735735
to_remove.append(node)
@@ -753,7 +753,7 @@ def inline_subgraph(parent, g, scope, binding):
753753
for node in g.inputs:
754754
parent_name = binding.get(node.output[0])
755755
if parent_name and parent_name != "@@ALLOC":
756-
node.input = [parent_name]
756+
g.replace_inputs(node, [parent_name])
757757
node.type = "Identity"
758758
else:
759759
to_remove.append(node)

tf2onnx/onnx_opset/generator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def version_1(cls, ctx, node, **kwargs):
4040
node.set_attr("seed", float(seed.f))
4141
if len(node.input) > 0:
4242
shape = node.inputs[0].get_tensor_value()
43-
ctx.remove_input(node, node.input[0])
43+
ctx.remove_input(node, node.input[0], 0)
4444
node.set_attr("shape", shape)
4545
ctx.set_shape(node.output[0], shape)
4646

@@ -53,7 +53,7 @@ def version_9(cls, ctx, node, **kwargs):
5353
node.set_attr("seed", float(seed.f))
5454
cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
5555
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
56-
node.input = const_node.output
56+
ctx.replace_inputs(node, const_node.output.copy())
5757
node.type = node.type + 'Like'
5858

5959

@@ -103,8 +103,8 @@ def version_7(cls, ctx, node, **kwargs):
103103
ctx.set_shape(tile_shape_int64.output[0], fill_shape)
104104

105105
tmp = node.input[0]
106-
node.input[0] = node.input[1]
107-
node.input[1] = tmp
106+
ctx.replace_input(node, node.input[0], node.input[1], 0)
107+
ctx.replace_input(node, node.input[1], tmp, 1)
108108
node.type = "Tile"
109109
ctx.set_dtype(node.output[0], new_dtype)
110110

@@ -128,13 +128,13 @@ def version_9(cls, ctx, node, **kwargs):
128128
value = np.array([node.inputs[1].get_tensor_value()]).astype(utils.map_onnx_to_numpy_type(dtype))
129129
value_proto = numpy_helper.from_array(value)
130130
node.set_attr("value", value_proto)
131-
del node.input[1]
131+
ctx.remove_input(node, node.input[1], 1)
132132

133133
@classmethod
134134
def version_11(cls, ctx, node, **kwargs):
135135
# cls.version_7(ctx, node, **kwargs)
136136
node.type = "Expand"
137-
node.input = [node.input[1], node.input[0]]
137+
ctx.replace_inputs(node, [node.input[1], node.input[0]])
138138
# cast shape to int64 if needed
139139
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
140140
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
@@ -156,7 +156,7 @@ def version_7(cls, ctx, node, **kwargs):
156156
output_dtype = onnx_pb.TensorProto.INT32
157157
node.set_attr("dtype", output_dtype)
158158
node.set_attr("sample_size", sample_size)
159-
ctx.remove_input(node, node.input[1])
159+
ctx.remove_input(node, node.input[1], 1)
160160

161161

162162
@tf_op("ZerosLike")

tf2onnx/onnx_opset/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def make_min_or_max_op(ctx, op_type, inputs, outputs,
123123
# use add as 'broadcast' op
124124
add_node = ctx.make_node("Add", [input_node.output[0], sub_node.output[0]],
125125
op_name_scope=input_node.name)
126-
node.input[i] = add_node.output[0]
126+
ctx.replace_input(node, node.input[i], add_node.output[0], i)
127127
return final_node
128128

129129

@@ -293,7 +293,7 @@ def version_1(cls, ctx, node, **kwargs):
293293
# workaround a bug in caffe2 pre Feb2018, pow(a, b) becomes np.exp(np.log(a) * b)
294294
node.type = "Log"
295295
b = node.input[1]
296-
ctx.remove_input(node, node.input[1])
296+
ctx.remove_input(node, node.input[1], 1)
297297
op_name = utils.make_name(node.name)
298298
mul_op = ctx.insert_new_node_on_output("Mul", node.output[0], name=op_name)
299299
mul_op.input.append(b)

tf2onnx/onnx_opset/nn.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
130130
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
131131

132132
reshape = ctx.make_node("Reshape", [kernel_name, shape_name])
133-
ctx.replace_input(node, kernel_name, reshape.output[0])
133+
ctx.replace_input(node, kernel_name, reshape.output[0], 1)
134134

135135
reshape.skip_conversion = True
136136

@@ -453,11 +453,11 @@ def version_1(cls, ctx, node, **kwargs):
453453
conv_dims_attr(node, "dilations", spatial=spatial)
454454

455455
# remove output_shapes input
456-
ctx.remove_input(node, node.input[0])
456+
ctx.remove_input(node, node.input[0], 0)
457457
# swap data and kernel
458458
t = node.input[0]
459-
node.input[0] = node.input[1]
460-
node.input[1] = t
459+
ctx.replace_input(node, node.input[0], node.input[1], 0)
460+
ctx.replace_input(node, node.input[1], t, 1)
461461

462462
conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial)
463463

@@ -539,8 +539,8 @@ def _convert(cls, ctx, node, **kwargs):
539539
else:
540540
kernel_shape_tf = node.inputs[1].get_tensor_value()
541541
strides_tf = node.inputs[2].get_tensor_value()
542-
ctx.remove_input(node, node.input[2])
543-
ctx.remove_input(node, node.input[1])
542+
ctx.remove_input(node, node.input[2], 2)
543+
ctx.remove_input(node, node.input[1], 1)
544544

545545
kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial)
546546
strides_hw = parse_dims_attr(node, strides_tf, spatial)
@@ -605,7 +605,7 @@ def version_7(cls, ctx, node, **kwargs):
605605
ctx.make_const(shape_name, np.array(new_broadcast_shape, dtype=np.int64))
606606
op_name = node.input[1]
607607
reshape_node = ctx.make_node("Reshape", [op_name, shape_name])
608-
ctx.replace_input(node, op_name, reshape_node.output[0])
608+
ctx.replace_input(node, op_name, reshape_node.output[0], 1)
609609
ctx.set_shape(reshape_node.output[0], new_broadcast_shape)
610610

611611

@@ -629,9 +629,9 @@ def version_1(cls, ctx, node, **kwargs):
629629
if mode in [None, "constant"] and len(node.input) == 3:
630630
const_val = node.inputs[2].get_tensor_value()
631631
node.set_attr("value", const_val)
632-
ctx.remove_input(node, node.input[2])
632+
ctx.remove_input(node, node.input[2], 2)
633633

634-
ctx.remove_input(node, node.input[1])
634+
ctx.remove_input(node, node.input[1], 1)
635635
node.set_attr("pads", paddings)
636636

637637
origin_dtype = ctx.get_dtype(node.output[0])
@@ -707,14 +707,14 @@ def version_6(cls, ctx, node, **kwargs):
707707
dtype=val_type)
708708
new_mean_node_name = utils.make_name(node.name)
709709
ctx.make_const(new_mean_node_name, new_mean_value)
710-
node.input[3] = new_mean_node_name
710+
ctx.replace_input(node, node.input[3], new_mean_node_name, 3)
711711

712712
if var_shape != scale_shape:
713713
new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(as_list=False), scale_shape),
714714
dtype=val_type)
715715
new_val_node_name = utils.make_name(node.name)
716716
ctx.make_const(new_val_node_name, new_var_value)
717-
node.input[4] = new_val_node_name
717+
ctx.replace_input(node, node.input[4], new_val_node_name, 4)
718718

719719
@classmethod
720720
def version_9(cls, ctx, node, **kwargs):
@@ -865,7 +865,7 @@ def version_7(cls, ctx, node, **kwargs):
865865
scaler = [1., 1., float(nh) / h, float(nw) / w]
866866
node.set_attr("scales", scaler)
867867
node.set_attr("mode", mode)
868-
ctx.remove_input(node, node.input[1])
868+
ctx.remove_input(node, node.input[1], 1)
869869
node.data_format = "NHWC"
870870
conv_convert_inputs(ctx, node, with_kernel=False)
871871

tf2onnx/onnx_opset/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def version_10(cls, ctx, node, **kwargs):
7070
op_name_scope=node.name, attr={"axis": axis},
7171
shapes=[shape], dtypes=[idtype])
7272
output_name = new_node.output[0]
73-
node.input[0] = output_name
73+
ctx.replace_input(node, node.input[0], output_name, 0)
7474

7575
ctx.remove_node(node.name)
7676

tf2onnx/onnx_opset/reduction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def version_1(cls, ctx, node, **kwargs):
4343
axes = [val + input_rank if val < 0 else val for val in axes]
4444

4545
node.set_attr("axes", axes)
46-
ctx.remove_input(node, node.input[1])
46+
ctx.remove_input(node, node.input[1], 1)
4747
keep_dims = node.get_attr("keep_dims")
4848
if keep_dims:
4949
del node.attr['keep_dims']
@@ -82,7 +82,7 @@ def version_1(cls, ctx, node, **kwargs):
8282

8383
node.set_attr("axis", axis)
8484
node.set_attr("keepdims", 0)
85-
ctx.remove_input(node, node.input[1])
85+
ctx.remove_input(node, node.input[1], 1)
8686

8787
@classmethod
8888
def version_11(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)