Skip to content

Commit 9ad9f50

Browse files
author
wayuanho
committed
fix bugs of strided_slice
1 parent da2bfbd commit 9ad9f50

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

tests/test_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,28 @@ def test_strided_slice6(self):
12301230
_ = tf.identity(x_, name=_TFOUTPUT)
12311231
self._run_test_case([_OUTPUT], {_INPUT: x_val})
12321232

1233+
@unittest.skipIf(BACKEND in ["caffe2"], "multiple dims not supported")
1234+
def test_strided_slice7(self):
1235+
x_val = np.arange(5*6).astype("float32").reshape(5, 6)
1236+
1237+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1238+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], begin_mask=2)
1239+
_ = tf.identity(x_, name=_TFOUTPUT)
1240+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1241+
1242+
tf.reset_default_graph()
1243+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1244+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], end_mask=2)
1245+
_ = tf.identity(x_, name=_TFOUTPUT)
1246+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1247+
1248+
tf.reset_default_graph()
1249+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1250+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], shrink_axis_mask=2)
1251+
_ = tf.identity(x_, name=_TFOUTPUT)
1252+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1253+
1254+
12331255
@unittest.skipIf(BACKEND in ["caffe2"], "fails with schema error")
12341256
@unittest.skipIf(*support_op_conversion_since(7, "batchnorm"))
12351257
def test_batchnorm(self):

tf2onnx/tfonnx.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -988,11 +988,14 @@ def stridedslice_op(ctx, node, name, args):
988988
if attr is not None and attr.i != 0:
989989
raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
990990
input_shape = ctx.get_shape(node.input[0])
991-
begin = node.inputs[1].get_tensor_value()
992-
end = node.inputs[2].get_tensor_value()
993-
strides = node.inputs[3].get_tensor_value()
991+
begin = node.inputs[1].get_tensor_value(as_list=False)
992+
end = node.inputs[2].get_tensor_value(as_list=False)
993+
strides = node.inputs[3].get_tensor_value(as_list=False)
994+
max_size = np.iinfo(begin.dtype).max
994995
end_mask = node.get_attr("end_mask")
995996
end_mask = end_mask.i if end_mask is not None else 0
997+
begin_mask = node.get_attr("begin_mask")
998+
begin_mask = begin_mask.i if begin_mask is not None else 0
996999
shrink_axis_mask = node.get_attr("shrink_axis_mask")
9971000
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
9981001
new_begin = []
@@ -1008,19 +1011,25 @@ def stridedslice_op(ctx, node, name, args):
10081011

10091012
# an implicit condition is stride == 1 (checked in above)
10101013
if begin_item < 0 and end_item == 0:
1011-
end_item = sys.maxsize
1014+
end_item = max_size
10121015

10131016
mask = (shrink_axis_mask >> idx) & 1
10141017
if mask != 0:
10151018
new_begin.append(begin_item)
1019+
end_item = begin_item + 1 if begin_item != -1 else max_size
10161020
new_end.append(end_item)
10171021
needs_squeeze.append(idx)
10181022
continue
10191023

1020-
new_begin.append(begin_item)
1024+
mask = (begin_mask >> idx) & 1
1025+
if mask != 0:
1026+
new_begin.append(0)
1027+
else:
1028+
new_begin.append(begin_item)
1029+
10211030
mask = (end_mask >> idx) & 1
10221031
if mask != 0:
1023-
new_end.append(sys.maxsize)
1032+
new_end.append(max_size)
10241033
else:
10251034
new_end.append(end_item)
10261035

0 commit comments

Comments
 (0)