Skip to content

Commit 212e119

Browse files
authored
Merge pull request #891 from onnx/gs/fix-negative-split
fix split in case of splits are negavitve
2 parents c2eef98 + 07a5ac3 commit 212e119

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

tests/test_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,13 @@ def func(x):
11601160
return tf.split(x, [4, 15, 11], 1, name="split_test")
11611161
self._run_test_case(func, ["split_test:0", "split_test:1", "split_test:2"], {_INPUT: x_val})
11621162

1163+
def test_negative_split(self):
1164+
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape((5, 30))
1165+
def func(x):
1166+
x_, _, _ = tf.split(x, [4, 15, -1], 1)
1167+
return tf.identity(x_, name=_TFOUTPUT)
1168+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1169+
11631170
def test_reducesum(self):
11641171
# not supported by onnx-caffe2
11651172
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,14 @@ def version_1(cls, ctx, node, **kwargs):
553553
node.type = "Split"
554554
split = node.inputs[1].get_tensor_value()
555555
split_dims = node.inputs[2].get_tensor_value()
556+
if -1 in split:
557+
# negative split = use the remaining size
558+
shape = ctx.get_shape(node.input[0])
559+
final_sum = shape[split_dims]
560+
sums = sum([i for i in split if i >= 0])
561+
for i, v in enumerate(split):
562+
if v == -1:
563+
split[i] = final_sum - sums
556564
ctx.remove_input(node, node.input[2])
557565
ctx.remove_input(node, node.input[1])
558566
node.set_attr("split", split)

0 commit comments

Comments
 (0)