Skip to content

Commit ef0af82

Browse files
Author: wayuanho
Merge pull request #589 from lucienwang1009/ellipse_mask: fix ellipsis_mask bug
2 parents afab9d7 + d4b9bf2 — commit ef0af82

File tree

2 files changed

+93
-15
lines changed

2 files changed

+93
-15
lines changed

tests/test_backend.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,28 @@ def test_strided_slice7(self):
16211621
_ = tf.identity(x_, name=_TFOUTPUT)
16221622
self._run_test_case([_OUTPUT], {_INPUT: x_val})
16231623

1624+
@skip_caffe2_backend("multiple dims not supported")
1625+
def test_strided_slice8(self):
1626+
x_val = np.arange(1 * 2 * 3 * 4 * 5 * 6).astype("float32").reshape((1, 2, 3, 4, 5, 6))
1627+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1628+
x_ = x[0:1, ..., 1, 2:, :6]
1629+
_ = tf.identity(x_, name=_TFOUTPUT)
1630+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1631+
1632+
tf.reset_default_graph()
1633+
x_val = np.arange(1 * 2 * 3 * 4 * 5 * 6).astype("float32").reshape((1, 2, 3, 4, 5, 6))
1634+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1635+
x_ = x[0:1, 1, 2:, :6, ...]
1636+
_ = tf.identity(x_, name=_TFOUTPUT)
1637+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1638+
1639+
tf.reset_default_graph()
1640+
x_val = np.arange(1 * 2 * 3 * 4 * 5 * 6).astype("float32").reshape((1, 2, 3, 4, 5, 6))
1641+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1642+
x_ = x[..., 0:1, 1, 2:, :6]
1643+
_ = tf.identity(x_, name=_TFOUTPUT)
1644+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1645+
16241646
@check_opset_min_version(10, "Slice")
16251647
@skip_caffe2_backend("multiple dims not supported")
16261648
def test_strided_slice_dynamic_1(self):
@@ -1702,6 +1724,35 @@ def test_strided_slice_dynamic_6(self):
17021724
_ = tf.identity(x_, name=_TFOUTPUT)
17031725
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
17041726

1727+
@check_opset_min_version(10, "Slice")
1728+
@skip_caffe2_backend("multiple dims not supported")
1729+
def test_strided_slice_dynamic_7(self):
1730+
x_val = np.arange(1 * 2 * 3 * 4 * 5 * 6).astype("float32").reshape((1, 2, 3, 4, 5, 6))
1731+
y_val = np.array(1, dtype=np.int32)
1732+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1733+
y = tf.placeholder(tf.int32, y_val.shape, name=_TFINPUT1)
1734+
x_ = x[0:y, ..., y, y:, :y]
1735+
_ = tf.identity(x_, name=_TFOUTPUT)
1736+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1737+
1738+
tf.reset_default_graph()
1739+
x_val = np.arange(1 * 2 * 3 * 4 * 5 * 6).astype("float32").reshape((1, 2, 3, 4, 5, 6))
1740+
y_val = np.array(1, dtype=np.int32)
1741+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1742+
y = tf.placeholder(tf.int32, y_val.shape, name=_TFINPUT1)
1743+
x_ = x[0:y, y, y:, :y, ...]
1744+
_ = tf.identity(x_, name=_TFOUTPUT)
1745+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1746+
1747+
tf.reset_default_graph()
1748+
x_val = np.arange(1 * 2 * 3 * 4 * 5 * 6).astype("float32").reshape((1, 2, 3, 4, 5, 6))
1749+
y_val = np.array(1, dtype=np.int32)
1750+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1751+
y = tf.placeholder(tf.int32, y_val.shape, name=_TFINPUT1)
1752+
x_ = x[..., 0:y, y, y:, :y]
1753+
_ = tf.identity(x_, name=_TFOUTPUT)
1754+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1755+
17051756
@skip_caffe2_backend("fails with schema error")
17061757
@check_opset_min_version(7, "batchnorm")
17071758
def test_batchnorm(self):

tf2onnx/onnx_opset/tensor.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ def version_1(cls, ctx, node, **kwargs):
567567
attr = node.get_attr(attr_name)
568568
if attr is not None and attr.i != 0:
569569
raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
570+
570571
onnx_dtype = ctx.get_dtype(node.input[1])
571572
np_dtype = utils.ONNX_TO_NUMPY_DTYPE[onnx_dtype]
572573
max_size = np.iinfo(np_dtype).max
@@ -586,17 +587,27 @@ def version_1(cls, ctx, node, **kwargs):
586587
axes = []
587588
# onnx slice op can't remove an axis; track axes and add a squeeze op if needed
588589
needs_squeeze = []
590+
# ellipsis: at most one bit can be 1. An ellipsis implicitly creates as many range specifications as
591+
# necessary to fully specify the sliced range for every dimension.
592+
# For example for a 4-dimensional tensor foo the slice foo[2, ..., 5:8] implies foo[2, :, :, 5:8]
593+
# NOTE: we ignore those axes denoted by ellipsis using `axes` attribute
594+
ellipsis_gap = 0
589595
for idx, begin_item in enumerate(begin):
590-
end_item = end[idx]
591596
if strides[idx] != 1:
592597
raise ValueError("StridedSlice: only strides=1 is supported")
593-
axes.append(idx)
594-
595598
if (ellipsis_mask >> idx) & 1:
596-
new_begin.append(0)
597-
new_end.append(max_size)
599+
input_shape = ctx.get_shape(node.input[0])
600+
utils.make_sure(
601+
input_shape is not None,
602+
"StridedSlice op {} requires the shape of input".format(node.name)
603+
)
604+
ellipsis_gap = len(input_shape) - len(begin)
598605
continue
599606

607+
# ignore ellipsis axes
608+
axes.append(idx + ellipsis_gap)
609+
end_item = end[idx]
610+
600611
# an implicit condition is stride == 1 (checked in above)
601612
if begin_item < 0 and end_item == 0:
602613
end_item = max_size
@@ -606,7 +617,7 @@ def version_1(cls, ctx, node, **kwargs):
606617
new_begin.append(begin_item)
607618
end_item = begin_item + 1 if begin_item != -1 else max_size
608619
new_end.append(end_item)
609-
needs_squeeze.append(idx)
620+
needs_squeeze.append(idx + ellipsis_gap)
610621
continue
611622

612623
mask = (begin_mask >> idx) & 1
@@ -698,30 +709,46 @@ def version_10(cls, ctx, node, **kwargs):
698709
ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
699710
shrink_axis_mask = node.get_attr("shrink_axis_mask")
700711
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
701-
input_shape = ctx.get_shape(node.input[1]) or \
712+
param_shape = ctx.get_shape(node.input[1]) or \
702713
ctx.get_shape(node.input[2]) or \
703714
ctx.get_shape(node.input[3])
704-
utils.make_sure(input_shape, "StridedSlice op {} requires the shape of begin/end/strides".format(node.name))
705-
input_rank = input_shape[0]
715+
utils.make_sure(
716+
param_shape is not None,
717+
"StridedSlice op {} requires the shape of begin/end/strides".format(node.name)
718+
)
719+
param_rank = param_shape[0]
706720
# use in onnx graph to mask begin
707-
new_begin_mask = [1] * input_rank
721+
new_begin_mask = [1] * param_rank
708722
# use in onnx graph to mask end
709-
new_end_mask = [min_size] * input_rank
723+
new_end_mask = [min_size] * param_rank
710724
# for shrink mask, if shrink mask is 1, set stride to be max_size
711-
shrink_strided_mask = [min_size] * input_rank
725+
shrink_strided_mask = [min_size] * param_rank
726+
axes = []
712727
# onnx slice op can't remove an axis; track axes and add a squeeze op if needed
713728
needs_squeeze = []
714-
for idx in range(input_rank):
729+
ellipsis_gap = 0
730+
for idx in range(param_rank):
715731
if (ellipsis_mask >> idx) & 1:
732+
input_shape = ctx.get_shape(node.input[0])
733+
utils.make_sure(
734+
input_shape is not None,
735+
"StridedSlice op {} requires the shape of input".format(node.name)
736+
)
737+
ellipsis_gap = len(input_shape) - param_rank
738+
# handle the redundant param
716739
new_begin_mask[idx] = 0
717740
new_end_mask[idx] = max_size
741+
axes.append(idx)
718742
continue
719743

744+
# ignore ellipsis axes
745+
axes.append(idx + ellipsis_gap)
746+
720747
mask = (shrink_axis_mask >> idx) & 1
721748
if mask != 0:
722749
shrink_strided_mask[idx] = max_size
723750
new_end_mask[idx] = max_size
724-
needs_squeeze.append(idx)
751+
needs_squeeze.append(idx + ellipsis_gap)
725752
continue
726753

727754
mask = (begin_mask >> idx) & 1
@@ -794,7 +821,7 @@ def version_10(cls, ctx, node, **kwargs):
794821
# create axes input
795822
axes_const = ctx.make_const(
796823
utils.make_name("slice_axes"),
797-
np.arange(input_rank, dtype=np_dtype)
824+
np.array(axes, dtype=np_dtype)
798825
)
799826
axes_output = axes_const.output[0]
800827

0 commit comments

Comments
 (0)