Commit 2e642d3
Fix common bug where node inputs are assumed to come from output[0] (#1307)
Signed-off-by: Tom Wildenhain <[email protected]>
Parent: 1e99bf4
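In the tf2onnx graph API, node.input is the list of input tensor names, while node.inputs is the list of producer nodes. The buggy pattern node.inputs[i].output[0] resolves an input through its producer's first output, which silently picks the wrong tensor whenever that producer has more than one output. A minimal sketch of the failure mode (illustrative tensor names, not code from this commit):

    # A producer op with two outputs; the consumer reads the *second* one.
    #   producer.output == ["p:0", "p:1"]
    #   consumer.input  == ["p:1"]
    tensor = consumer.inputs[0].output[0]  # bug: resolves to "p:0"
    tensor = consumer.input[0]             # fix: resolves to "p:1"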

3 files changed: 25 additions, 25 deletions

tf2onnx/custom_opsets/ms.py

Lines changed: 4 additions & 4 deletions

@@ -63,10 +63,10 @@ def version_1(cls, ctx, node, **kwargs):
         input_shape = ctx.make_node("Shape", [node.input[2]])
         hw_indices = ctx.make_const(utils.make_name("hw_indices"), np.array([1, 2]).astype(np.int64))
         input_shape_hw = ctx.make_node("Gather", [input_shape.output[0], hw_indices.output[0]])
-        output_shape = node.inputs[0]
-        if ctx.get_dtype(output_shape.output[0]) != onnx_pb.TensorProto.INT64:
-            output_shape = ctx.make_node("Cast", [output_shape.output[0]], attr={"to": onnx_pb.TensorProto.INT64})
-        output_shape_hw = ctx.make_node("Gather", [output_shape.output[0], hw_indices.output[0]])
+        output_shape = node.input[0]
+        if ctx.get_dtype(output_shape) != onnx_pb.TensorProto.INT64:
+            output_shape = ctx.make_node("Cast", [output_shape], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
+        output_shape_hw = ctx.make_node("Gather", [output_shape, hw_indices.output[0]])
         kernel_shape_hw = list(ctx.get_shape(node.input[1]))[0:2]
         kernel_shape = ctx.make_const(utils.make_name("const_convtrans"), np.array(kernel_shape_hw).astype(np.int64))
         strides = conv_dims_attr(node, "strides")
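The fixed branch also keeps output_shape a tensor name on both paths: the result of the inserted Cast is unwrapped with .output[0] immediately, so the later get_dtype and Gather calls always operate on a name. A condensed sketch of the fixed flow (same ctx helpers as in the diff above):

    output_shape = node.input[0]  # a tensor name, not a producer node
    if ctx.get_dtype(output_shape) != onnx_pb.TensorProto.INT64:
        # unwrap the new Cast node back into a name right away
        output_shape = ctx.make_node(
            "Cast", [output_shape], attr={"to": onnx_pb.TensorProto.INT64}
        ).output[0]
    output_shape_hw = ctx.make_node("Gather", [output_shape, hw_indices.output[0]])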

tf2onnx/onnx_opset/math.py

Lines changed: 9 additions & 9 deletions

@@ -154,25 +154,25 @@ def version_8(cls, ctx, node, **kwargs):
         # fetch those upfront since they are not accessible once we remove 'node'
         shapes = node.output_shapes
         dtypes = node.output_dtypes
-        input_dtype = node.inputs[0].output_dtypes[0]
+        input_dtype = ctx.get_dtype(node.input[0])
         name = node.name
-        min_node = node.inputs[1]
-        if min_node.output_dtypes[0] not in supported:
+        min_node = node.input[1]
+        if ctx.get_dtype(min_node) not in supported:
             # cast min if needed
-            min_node = ctx.insert_new_node_on_input(node, "Cast", min_node.output[0], to=onnx_pb.TensorProto.FLOAT)
-        max_node = node.inputs[2]
-        if max_node.output_dtypes[0] not in supported:
+            min_node = ctx.insert_new_node_on_input(node, "Cast", min_node, to=onnx_pb.TensorProto.FLOAT).output[0]
+        max_node = node.input[2]
+        if ctx.get_dtype(max_node) not in supported:
             # cast max if needed
-            max_node = ctx.insert_new_node_on_input(node, "Cast", max_node.output[0], to=onnx_pb.TensorProto.FLOAT)
+            max_node = ctx.insert_new_node_on_input(node, "Cast", max_node, to=onnx_pb.TensorProto.FLOAT).output[0]
         ctx.remove_node(name)
-        new_node = ctx.make_node("Max", [node.input[0], min_node.output[0]], outputs=[node.output[0]],
+        new_node = ctx.make_node("Max", [node.input[0], min_node], outputs=[node.output[0]],
                                  shapes=shapes, dtypes=dtypes)
         if input_dtype not in supported:
             # cast the data tensor if needed
             ctx.insert_new_node_on_input(new_node, "Cast", new_node.input[0], to=onnx_pb.TensorProto.FLOAT)

         new_node = ctx.insert_new_node_on_output("Min", new_node.output[0], name=utils.make_name(name))
-        new_node.input.append(max_node.output[0])
+        new_node.input.append(max_node)
         # copy shape and type
         ctx.set_dtype(new_node.output[0], dtypes[0])
         ctx.set_shape(new_node.output[0], shapes[0])
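The same discipline applies here: min_node and max_node now hold tensor names end to end, and insert_new_node_on_input, which returns the newly created Cast node, is unwrapped with .output[0] before the name is reused. A condensed sketch of the min path after the fix (the max path is symmetric):

    min_node = node.input[1]  # tensor name of the min input
    if ctx.get_dtype(min_node) not in supported:
        # insert a Cast on the input edge, then keep the cast tensor's name
        min_node = ctx.insert_new_node_on_input(
            node, "Cast", min_node, to=onnx_pb.TensorProto.FLOAT
        ).output[0]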

tf2onnx/onnx_opset/nn.py

Lines changed: 12 additions & 12 deletions

@@ -315,10 +315,10 @@ def build_dynamic_target_size(ctx, transposed_intput, target_hw):
         A tensor of rank 2 containing [n c nh nw]
     """
     # We get the first half [n c] of the target shape
-    shape_of_transposed_input = ctx.make_node("Shape", [transposed_intput.output[0]])
+    shape_of_transposed_input = ctx.make_node("Shape", [transposed_intput])
     first_half_of_shape = GraphBuilder(ctx).make_slice(
         {"data": shape_of_transposed_input.output[0], "ends": [2], "starts": [0]})
-    target_size_int64 = ctx.make_node("Cast", [target_hw.output[0]], attr={'to': TensorProto.INT64})
+    target_size_int64 = ctx.make_node("Cast", [target_hw], attr={'to': TensorProto.INT64})
     # We build a tensor containing [n c nh nw]
     final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_size_int64.output[0]], {'axis': 0})
     return final_target_size
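With this change build_dynamic_target_size consumes tensor names for both transposed_intput and target_hw, so callers must now pass some_node.output[0] or a node.input[i] name rather than a node object, exactly as the updated call sites later in this diff do:

    # updated caller pattern (taken from the _convert_since_9 hunk below)
    input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
    final_target_size = build_dynamic_target_size(ctx, input_nchw.output[0], node.input[1])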
@@ -916,10 +916,10 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
         mode = "nearest" if node.get_attr("method") is not None and node.get_attr(
             "method").s == b"nearest" else "linear"
         extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
-        input_x = node.inputs[0]
-        boxes = node.inputs[1]
-        box_ind = node.inputs[2]
-        crop_size = node.inputs[3]
+        input_x = node.input[0]
+        boxes = node.input[1]
+        box_ind = node.input[2]
+        crop_size = node.input[3]
         trip_name = utils.make_name(node.name + "_i")
         cond_name = utils.make_name(node.name + "_cond")
         cond_out_name = utils.make_name(node.name + "cond_out")
@@ -932,9 +932,9 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
         const_one = g.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int32))
         const_one_long = g.make_const(utils.make_name(node.name + "_const_one_long"), np.array([1], dtype=np.int64))
         index_end = g.make_node("Add", [trip_name, const_one_long.output[0]])
-        box_index_from = g.make_node("Slice", [box_ind.output[0], trip_name, index_end.output[0]], name="Slice_a")
+        box_index_from = g.make_node("Slice", [box_ind, trip_name, index_end.output[0]], name="Slice_a")
         box_index_to = g.make_node("Add", [box_index_from.output[0], const_one.output[0]])
-        target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0], box_index_to.output[0],
+        target_x = g.make_node("Slice", [input_x, box_index_from.output[0], box_index_to.output[0],
                                          const_zero.output[0]], name="Slice_b")
         transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
         const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
@@ -943,15 +943,15 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
                                       np.array([1, 1], dtype=np.float32))
         const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
         const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
-        box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0], const_zero_long.output[0]],
+        box = g.make_node("Slice", [boxes, trip_name, index_end.output[0], const_zero_long.output[0]],
                           name="Slice_c")
         roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
         roi_raw_first_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [2], "starts": [0]})
         roi_raw_second_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [4], "starts": [2]})
         roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
         roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
         final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})
-        final_crop_size = build_dynamic_target_size(g, transposed_x, crop_size)
+        final_crop_size = build_dynamic_target_size(g, transposed_x.output[0], crop_size)
         resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0], const_empty_float.output[0],
                                            final_crop_size.output[0]],
                                 attr={"mode": mode, "extrapolation_value": extrapolation_value,
@@ -961,7 +961,7 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
         g.make_node("Identity", [cond_name], outputs=[cond_out_name])
         g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
         g.add_graph_output(squeeze_x.output[0], ctx.get_dtype(node.input[0]), [-1, -1, -1])
-        trip_node = ctx.make_node("Size", [box_ind.output[0]])
+        trip_node = ctx.make_node("Size", [box_ind])
         cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
         ctx.remove_node(node.name)
         branches = {"body": g}
@@ -1070,7 +1070,7 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
         # because onnxruntime only supports to scale the last two dims so transpose is inserted
         input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
         if use_target_size:
-            final_target_size = build_dynamic_target_size(ctx, input_nchw, node.inputs[1])
+            final_target_size = build_dynamic_target_size(ctx, input_nchw.output[0], node.input[1])
         roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
         const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
         resize_inputs = [
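Taken together, the edits converge on one rule: variables that feed make_node inputs hold tensor names, and node objects are kept only when the producer itself must be inspected. A short sketch of when each accessor fits (tf2onnx-style attributes; treat the exact method names as assumptions):

    producer = node.inputs[0]                # the producing op, for inspection
    if producer.type == "Const":
        value = producer.get_tensor_value()  # needs the node object
    tensor_name = node.input[0]              # graph edges are wired by name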
