Skip to content

Commit a2b8085

Browse files
authored
Merge pull request #388 from jiafatom/ellipse_mask
Support ellipse mask in stridedslice op
2 parents 5feb8e8 + 2e47674 commit a2b8085

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

tests/test_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,12 @@ def test_strided_slice7(self):
12431243
_ = tf.identity(x_, name=_TFOUTPUT)
12441244
self._run_test_case([_OUTPUT], {_INPUT: x_val})
12451245

1246+
tf.reset_default_graph()
1247+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1248+
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], ellipsis_mask=2)
1249+
_ = tf.identity(x_, name=_TFOUTPUT)
1250+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1251+
12461252
@skip_caffe2_backend("fails with schema error")
12471253
@check_opset_min_version(7, "batchnorm")
12481254
def test_batchnorm(self):

tf2onnx/tfonnx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ def expanddims_op7(ctx, node, name, args):
926926

927927
def stridedslice_op(ctx, node, name, args):
928928
# for now we implement common cases. Things like strides!=1 are not mappable to onnx.
929-
not_supported_attr = ["ellipsis_mask", "new_axis_mask"]
929+
not_supported_attr = ["new_axis_mask"]
930930
for attr_name in not_supported_attr:
931931
attr = node.get_attr(attr_name)
932932
if attr is not None and attr.i != 0:
@@ -942,6 +942,8 @@ def stridedslice_op(ctx, node, name, args):
942942
begin_mask = begin_mask.i if begin_mask is not None else 0
943943
shrink_axis_mask = node.get_attr("shrink_axis_mask")
944944
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
945+
ellipsis_mask = node.get_attr("ellipsis_mask")
946+
ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
945947
new_begin = []
946948
new_end = []
947949
axes = []
@@ -953,6 +955,11 @@ def stridedslice_op(ctx, node, name, args):
953955
raise ValueError("StridedSlice: only strides=1 is supported")
954956
axes.append(idx)
955957

958+
if (ellipsis_mask >> idx) & 1:
959+
new_begin.append(0)
960+
new_end.append(max_size)
961+
continue
962+
956963
# an implicit condition is stride == 1 (checked in above)
957964
if begin_item < 0 and end_item == 0:
958965
end_item = max_size

0 commit comments

Comments
 (0)