Skip to content

Commit ef8f7f3

Browse files
Fix Slice that is Identity (#1476)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 596f237 commit ef8f7f3

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2540,6 +2540,16 @@ def func2(x):
25402540
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
25412541
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})
25422542

2543+
@check_opset_min_version(10, "Slice")
2544+
@skip_tflite("not supported in tflite")
2545+
def test_strided_slice_only_ellipsis(self):
2546+
def func1(x):
2547+
x_ = x[...]
2548+
return tf.identity(x_, name=_TFOUTPUT)
2549+
shape = [1, 8, 64]
2550+
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
2551+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
2552+
25432553
@check_opset_min_version(7, "batchnorm")
25442554
def test_fused_batchnorm(self):
25452555
x_shape = [1, 28, 28, 2]

tf2onnx/onnx_opset/tensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,11 @@ def version_1(cls, ctx, node, **kwargs):
849849
attr = {"starts": new_begin, "ends": new_end, "axes": axes}
850850
inputs_map = {"data": node.input[0], **attr}
851851
kwargs = {**inputs_map, "outputs": node.output}
852-
node = GraphBuilder(ctx).make_slice(
853-
kwargs, name=node.name, dtypes=out_dtypes, shapes=out_shapes, return_node=True)
852+
if len(axes) > 0:
853+
node = GraphBuilder(ctx).make_slice(
854+
kwargs, name=node.name, dtypes=out_dtypes, shapes=out_shapes, return_node=True)
855+
else:
856+
node = ctx.make_node("Identity", [node.input[0]], name=node.name, dtypes=out_dtypes, shapes=out_shapes)
854857
nodes = [node]
855858
if needs_squeeze:
856859
# insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)

0 commit comments

Comments
 (0)