Skip to content

Commit 96a040b

Browse files
Fix Slice conversion for nodes with multiple outputs (#1473)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 94d8113 commit 96a040b

File tree

2 files changed

+39
-37
lines changed

2 files changed

+39
-37
lines changed

tf2onnx/graph.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,12 @@ def set_node_by_name(self, node):
868868
for name in node.input:
869869
self._register_input_name(name, node)
870870

871+
def is_const(self, output):
872+
return self.get_node_by_output(output).is_const()
873+
874+
def get_tensor_value(self, output, as_list=True):
875+
return self.get_node_by_output(output).get_tensor_value(as_list)
876+
871877
def change_node_name(self, node, new_name):
872878
"""Remove node in current graph."""
873879
utils.make_sure(new_name not in self._nodes_by_name, "node %s not unique ", new_name)

tf2onnx/onnx_opset/tensor.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -896,15 +896,15 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
896896
# @int shrink_axis_mask, @int new_axis_mask)
897897
# T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
898898
# "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
899-
input_x = node.inputs[0]
900-
begin = node.inputs[1]
901-
end = node.inputs[2]
902-
strides = node.inputs[3]
899+
input_x = node.input[0]
900+
begin = node.input[1]
901+
end = node.input[2]
902+
strides = node.input[3]
903903
new_axis_mask = node.get_attr("new_axis_mask")
904904
new_axis_mask = new_axis_mask.i if new_axis_mask is not None else 0
905905

906-
if begin.is_const() and end.is_const() and strides.is_const() \
907-
and all(val == 1 for val in strides.get_tensor_value()) \
906+
if ctx.is_const(begin) and ctx.is_const(end) and ctx.is_const(strides) \
907+
and all(val == 1 for val in ctx.get_tensor_value(strides)) \
908908
and new_axis_mask == 0:
909909
cls.version_1(ctx, node, **kwargs)
910910
return
@@ -945,7 +945,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
945945
if (new_axis_mask >> bit) & 1 == 1:
946946
num_new += 1
947947
if (ellipsis_mask >> bit) & 1:
948-
input_shape = ctx.get_shape(input_x.output[0])
948+
input_shape = ctx.get_shape(input_x)
949949
# calculate what rank for ellipsis: input rank - (being rank - all new_axis - 1)
950950
ellipsis_gap = len(input_shape) - param_rank + num_new + 1
951951
if (new_axis_mask >> bit) & 1 == 1:
@@ -954,7 +954,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
954954
end_mask |= 1 << bit
955955

956956
input_x = GraphBuilder(ctx).make_unsqueeze(
957-
{'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True)
957+
{'data': input_x, 'axes': unqueeze_at})
958958

959959

960960
# use in onnx graph to mask begin
@@ -969,7 +969,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
969969
ellipsis_gap = 0
970970
for idx in range(param_rank):
971971
if (ellipsis_mask >> idx) & 1:
972-
input_shape = ctx.get_shape(input_x.output[0])
972+
input_shape = ctx.get_shape(input_x)
973973
utils.make_sure(
974974
input_shape is not None,
975975
"StridedSlice op {} requires the shape of input".format(node.name)
@@ -1006,34 +1006,32 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
10061006
# mask begin
10071007
new_begin_mask = np.array(new_begin_mask, dtype=np_dtype)
10081008
if not np.all(new_begin_mask == 1):
1009-
if begin.is_const() and strides.is_const():
1010-
new_begin_vals = np.copy(begin.get_tensor_value(as_list=False))
1011-
strides_vals = strides.get_tensor_value(as_list=False)
1009+
if ctx.is_const(begin) and ctx.is_const(strides):
1010+
new_begin_vals = np.copy(ctx.get_tensor_value(begin, as_list=False))
1011+
strides_vals = ctx.get_tensor_value(strides, as_list=False)
10121012
idx1 = np.where(new_begin_mask == 0)
10131013
idx2 = np.where(strides_vals < 0)
10141014
idx3 = np.intersect1d(idx1, idx2)
10151015
new_begin_vals[idx3] = max_size
1016-
begin = ctx.make_const(utils.make_name("begin_masked"), new_begin_vals)
1016+
begin = ctx.make_const(utils.make_name("begin_masked"), new_begin_vals).output[0]
10171017
else:
10181018
begin_mask_const = ctx.make_const(utils.make_name("begin_mask"), np.equal(new_begin_mask, 0))
10191019
zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
10201020
max_const = ctx.make_const(utils.make_name("max_const"), np.array(max_size, dtype=np_dtype))
1021-
op1 = ctx.make_node("Less", [strides.output[0], zero_const.output[0]], op_name_scope=node.name)
1021+
op1 = ctx.make_node("Less", [strides, zero_const.output[0]], op_name_scope=node.name)
10221022
op2 = ctx.make_node("And", [op1.output[0], begin_mask_const.output[0]], op_name_scope=node.name)
1023-
begin = ctx.make_node("Where", [op2.output[0], max_const.output[0], begin.output[0]],
1024-
op_name_scope=node.name)
1023+
begin = ctx.make_node("Where", [op2.output[0], max_const.output[0], begin],
1024+
op_name_scope=node.name).output[0]
10251025

10261026
# mask end
10271027
new_end_mask = np.array(new_end_mask, dtype=np_dtype)
1028-
end_output = end.output[0]
10291028
if not np.all(new_end_mask == min_size):
1030-
if end.is_const() and strides.is_const():
1031-
new_end_mask = np.maximum(end.get_tensor_value(as_list=False), new_end_mask)
1029+
if ctx.is_const(end) and ctx.is_const(strides):
1030+
new_end_mask = np.maximum(ctx.get_tensor_value(end, as_list=False), new_end_mask)
10321031
idx = np.where(new_end_mask == max_size)
1033-
sign = np.sign(strides.get_tensor_value(as_list=False))[idx]
1032+
sign = np.sign(ctx.get_tensor_value(strides, as_list=False))[idx]
10341033
new_end_mask[idx] = new_end_mask[idx] * sign
1035-
end = ctx.make_const(utils.make_name("end_masked"), new_end_mask)
1036-
end_output = end.output[0]
1034+
end = ctx.make_const(utils.make_name("end_masked"), new_end_mask).output[0]
10371035
else:
10381036
# Overlay new_end_mask with specified end values.
10391037
# Adjust max_size to min_size if steps are < 0
@@ -1042,25 +1040,22 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
10421040
zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
10431041
end_mask_const = ctx.make_const(utils.make_name("end_mask"), np.array(new_end_mask, dtype=np_dtype))
10441042
outputname = utils.make_name("{}__newendmask".format(node.name))
1045-
new_end_mask = math.make_min_or_max_op(ctx, "Max", [end.output[0], end_mask_const.output[0]],
1043+
new_end_mask = math.make_min_or_max_op(ctx, "Max", [end, end_mask_const.output[0]],
10461044
[outputname])
1047-
op1 = ctx.make_node("Less", [strides.output[0], zero_const.output[0]], op_name_scope=node.name)
1045+
op1 = ctx.make_node("Less", [strides, zero_const.output[0]], op_name_scope=node.name)
10481046
op2 = ctx.make_node("Equal", [new_end_mask.output[0], max_const.output[0]], op_name_scope=node.name)
10491047
op3 = ctx.make_node("And", [op2.output[0], op1.output[0]], op_name_scope=node.name)
1050-
final_end = ctx.make_node("Where", [op3.output[0], min_const.output[0],
1051-
new_end_mask.output[0]], op_name_scope=node.name)
1052-
end_output = final_end.output[0]
1048+
end = ctx.make_node("Where", [op3.output[0], min_const.output[0], new_end_mask.output[0]],
1049+
op_name_scope=node.name).output[0]
10531050

10541051
# mask strides for shrink
10551052
shrink_strided_mask = np.array(shrink_strided_mask, dtype=np_dtype)
1056-
strides_output = strides.output[0]
10571053
if not np.all(shrink_strided_mask == min_size):
1058-
if strides.is_const():
1054+
if ctx.is_const(strides):
10591055
strides = ctx.make_const(
10601056
utils.make_name("strides_masked"),
1061-
np.maximum(strides.get_tensor_value(as_list=False), shrink_strided_mask)
1062-
)
1063-
strides_output = strides.output[0]
1057+
np.maximum(ctx.get_tensor_value(strides, as_list=False), shrink_strided_mask)
1058+
).output[0]
10641059
else:
10651060
shrink_strided_mask_const = ctx.make_const(
10661061
utils.make_name("strides_mask"),
@@ -1069,9 +1064,10 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
10691064
strides_output = utils.make_name("{}__strides".format(node.name))
10701065
math.make_min_or_max_op(
10711066
ctx, "Max",
1072-
[strides.output[0], shrink_strided_mask_const.output[0]],
1067+
[strides, shrink_strided_mask_const.output[0]],
10731068
[strides_output]
10741069
)
1070+
strides = strides_output
10751071
# create axes input
10761072
axes_const = ctx.make_const(
10771073
utils.make_name("slice_axes"),
@@ -1080,10 +1076,10 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
10801076
axes_output = axes_const.output[0]
10811077

10821078
inputs_map = {
1083-
"data": input_x.output[0],
1084-
"starts": begin.output[0],
1085-
"ends": end_output,
1086-
"steps": strides_output,
1079+
"data": input_x,
1080+
"starts": begin,
1081+
"ends": end,
1082+
"steps": strides,
10871083
"axes": axes_output
10881084
}
10891085
kwargs = {**inputs_map, "outputs": node.output}

0 commit comments

Comments
 (0)