Skip to content

Commit 67da5e1

Browse files
authored
fix StridedSlice for ellipsis+newaxis (#1319)
* fix StridedSlice for ellipsis+newaxis Signed-off-by: Guenther Schmuelling <[email protected]> * skip ut for tflite Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent c8d7a3b commit 67da5e1

File tree

3 files changed

+45
-12
lines changed

3 files changed

+45
-12
lines changed

tests/backend_test_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,17 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
7676
def run_onnxruntime(self, model_path, inputs, output_names):
7777
"""Run test against onnxruntime backend."""
7878
import onnxruntime as rt
79+
providers = ['CPUExecutionProvider']
80+
if rt.get_device() == "GPU":
81+
gpus = os.environ.get("CUDA_VISIBLE_DEVICES")
82+
if gpus is None or len(gpus) > 1:
83+
providers = ['CUDAExecutionProvider']
7984
opt = rt.SessionOptions()
8085
# in case of issues with the runtime, one can enable more logging
8186
# opt.log_severity_level = 0
8287
# opt.log_verbosity_level = 255
8388
# opt.enable_profiling = True
84-
m = rt.InferenceSession(model_path, opt)
89+
m = rt.InferenceSession(model_path, opt, providers=providers)
8590
results = m.run(output_names, inputs)
8691
return results
8792

tests/test_backend.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33

4-
54
"""Unit tests using onnx backends."""
65

76
from __future__ import division
@@ -2230,6 +2229,23 @@ def func(x, y):
22302229
y_val = np.array(9, dtype=np.int32)
22312230
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
22322231

2232+
@check_opset_min_version(10, "Slice")
2233+
@skip_tflite("not supported in tflite")
2234+
def test_strided_slice_ellipse(self):
2235+
def func1(x):
2236+
x_ = x[..., tf.newaxis]
2237+
return tf.identity(x_, name=_TFOUTPUT)
2238+
shape = [1, 8, 64]
2239+
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
2240+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
2241+
2242+
def func2(x):
2243+
x_ = x[:, tf.newaxis, ..., :, tf.newaxis]
2244+
return tf.identity(x_, name=_TFOUTPUT)
2245+
shape = [2, 3, 4, 5]
2246+
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
2247+
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})
2248+
22332249
@check_opset_min_version(7, "batchnorm")
22342250
def test_fused_batchnorm(self):
22352251
x_shape = [1, 28, 28, 2]

tf2onnx/onnx_opset/tensor.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def version_1(cls, ctx, node, **kwargs):
7272
ctx.copy_shape(output_name, output_cast.output[0])
7373

7474

75-
7675
@tf_op("Flatten")
7776
class Flatten:
7877
@classmethod
@@ -630,6 +629,7 @@ def version_13(cls, ctx, node, **kwargs):
630629
# Default axis is not -1 but doesn't matter since we always set it.
631630
cls.version_1(ctx, node, **kwargs)
632631

632+
633633
@tf_op("SplitV")
634634
class SplitV:
635635
@classmethod
@@ -874,15 +874,6 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
874874
ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
875875
shrink_axis_mask = node.get_attr("shrink_axis_mask")
876876
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
877-
if new_axis_mask != 0:
878-
unqueeze_at = []
879-
for bit in range(32):
880-
if (new_axis_mask >> bit) & 1 == 1:
881-
unqueeze_at.append(bit)
882-
begin_mask |= 1 << bit
883-
end_mask |= 1 << bit
884-
input_x = GraphBuilder(ctx).make_unsqueeze(
885-
{'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True)
886877

887878
param_shape = ctx.get_shape(node.input[1]) or \
888879
ctx.get_shape(node.input[2]) or \
@@ -892,6 +883,27 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
892883
"StridedSlice op {} requires the shape of begin/end/strides".format(node.name)
893884
)
894885
param_rank = param_shape[0]
886+
887+
if new_axis_mask != 0:
888+
unqueeze_at = []
889+
ellipsis_gap = 0
890+
num_new = 0
891+
for bit in range(32):
892+
if (new_axis_mask >> bit) & 1 == 1:
893+
num_new += 1
894+
if (ellipsis_mask >> bit) & 1:
895+
input_shape = ctx.get_shape(input_x.output[0])
896+
# calculate what rank for ellipsis: input rank - (being rank - all new_axis - 1)
897+
ellipsis_gap = len(input_shape) - param_rank + num_new + 1
898+
if (new_axis_mask >> bit) & 1 == 1:
899+
unqueeze_at.append(bit + ellipsis_gap)
900+
begin_mask |= 1 << bit
901+
end_mask |= 1 << bit
902+
903+
input_x = GraphBuilder(ctx).make_unsqueeze(
904+
{'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True)
905+
906+
895907
# use in onnx graph to mask begin
896908
new_begin_mask = [1] * param_rank
897909
# use in onnx graph to mask end

0 commit comments

Comments
 (0)