Skip to content

Commit 15ac39c

Browse files
authored
Merge pull request #315 from lucienwang1009/strided_slice
refine strided_slice
2 parents 717721a + 9ad9f50 commit 15ac39c

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

tests/test_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,28 @@ def test_strided_slice6(self):
12421242
_ = tf.identity(x_, name=_TFOUTPUT)
12431243
self._run_test_case([_OUTPUT], {_INPUT: x_val})
12441244

1245+
@unittest.skipIf(BACKEND in ["caffe2"], "multiple dims not supported")
1246+
def test_strided_slice7(self):
1247+
x_val = np.arange(5*6).astype("float32").reshape(5, 6)
1248+
1249+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1250+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], begin_mask=2)
1251+
_ = tf.identity(x_, name=_TFOUTPUT)
1252+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1253+
1254+
tf.reset_default_graph()
1255+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1256+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], end_mask=2)
1257+
_ = tf.identity(x_, name=_TFOUTPUT)
1258+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1259+
1260+
tf.reset_default_graph()
1261+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1262+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], shrink_axis_mask=2)
1263+
_ = tf.identity(x_, name=_TFOUTPUT)
1264+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1265+
1266+
12451267
@unittest.skipIf(BACKEND in ["caffe2"], "fails with schema error")
12461268
@unittest.skipIf(*support_op_conversion_since(7, "batchnorm"))
12471269
def test_batchnorm(self):

tf2onnx/tfonnx.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -998,11 +998,14 @@ def stridedslice_op(ctx, node, name, args):
998998
if attr is not None and attr.i != 0:
999999
raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
10001000
input_shape = ctx.get_shape(node.input[0])
1001-
begin = node.inputs[1].get_tensor_value()
1002-
end = node.inputs[2].get_tensor_value()
1003-
strides = node.inputs[3].get_tensor_value()
1001+
begin = node.inputs[1].get_tensor_value(as_list=False)
1002+
end = node.inputs[2].get_tensor_value(as_list=False)
1003+
strides = node.inputs[3].get_tensor_value(as_list=False)
1004+
max_size = np.iinfo(begin.dtype).max
10041005
end_mask = node.get_attr("end_mask")
10051006
end_mask = end_mask.i if end_mask is not None else 0
1007+
begin_mask = node.get_attr("begin_mask")
1008+
begin_mask = begin_mask.i if begin_mask is not None else 0
10061009
shrink_axis_mask = node.get_attr("shrink_axis_mask")
10071010
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
10081011
new_begin = []
@@ -1018,19 +1021,25 @@ def stridedslice_op(ctx, node, name, args):
10181021

10191022
# an implicit condition is stride == 1 (checked in above)
10201023
if begin_item < 0 and end_item == 0:
1021-
end_item = sys.maxsize
1024+
end_item = max_size
10221025

10231026
mask = (shrink_axis_mask >> idx) & 1
10241027
if mask != 0:
10251028
new_begin.append(begin_item)
1029+
end_item = begin_item + 1 if begin_item != -1 else max_size
10261030
new_end.append(end_item)
10271031
needs_squeeze.append(idx)
10281032
continue
10291033

1030-
new_begin.append(begin_item)
1034+
mask = (begin_mask >> idx) & 1
1035+
if mask != 0:
1036+
new_begin.append(0)
1037+
else:
1038+
new_begin.append(begin_item)
1039+
10311040
mask = (end_mask >> idx) & 1
10321041
if mask != 0:
1033-
new_end.append(sys.maxsize)
1042+
new_end.append(max_size)
10341043
else:
10351044
new_end.append(end_item)
10361045

0 commit comments

Comments
 (0)