Skip to content

Commit 82f805f

Browse files
authored
Merge pull request #429 from lucienwang1009/slice_bug
fix Slice bug
2 parents 9ffbbfd + f5b8d2a commit 82f805f

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

tf2onnx/tfonnx.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,8 +797,16 @@ def slice_op(ctx, node, name, args):
797797
# T output = Slice(T input, Index begin, Index size, @type Index)
798798
# T output = Slice(T data, @INTS axes, @INTS ends, @INTS starts)
799799
starts = node.inputs[1].get_tensor_value()
800-
size = node.inputs[2].get_tensor_value()
801-
ends = np.add(starts, size)
800+
sizes = node.inputs[2].get_tensor_value()
801+
ends = []
802+
for start, size in zip(starts, sizes):
803+
# get all elements
804+
if size == -1:
805+
dtype = ctx.get_dtype(node.input[1])
806+
utils.make_sure(dtype, "dtype of {} is None".format(node.input[1]))
807+
ends.append(np.iinfo(dtype).max)
808+
else:
809+
ends.append(start + size)
802810
ctx.remove_input(node, node.input[2])
803811
ctx.remove_input(node, node.input[1])
804812
node.set_attr("starts", starts)

0 commit comments

Comments
 (0)