Skip to content

Commit 3866084

Browse files
committed
Merge pull request #16797 from phlrain/fix_split
Fix split
1 parent 7b45363 commit 3866084

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

paddle/fluid/operators/split_op.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,22 @@ class SplitOp : public framework::OperatorWithKernel {
3939

4040
if (num > 0) {
4141
int64_t in_axis_dim = in_dims[axis];
42-
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
43-
"tensor split does not result"
44-
" in an equal division");
45-
size_t out_axis_dim = in_axis_dim / num;
46-
for (size_t i = 0; i < outs_number; ++i) {
47-
auto dim = in_dims;
48-
dim[axis] = out_axis_dim;
49-
outs_dims.push_back(dim);
42+
if (ctx->IsRuntime() || in_axis_dim > 0) {
43+
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
44+
"tensor split does not result"
45+
" in an equal division");
46+
size_t out_axis_dim = in_axis_dim / num;
47+
for (size_t i = 0; i < outs_number; ++i) {
48+
auto dim = in_dims;
49+
dim[axis] = out_axis_dim;
50+
outs_dims.push_back(dim);
51+
}
52+
} else {
53+
for (size_t i = 0; i < outs_number; ++i) {
54+
auto dim = in_dims;
55+
dim[axis] = -1;
56+
outs_dims.push_back(dim);
57+
}
5058
}
5159
} else if (sections.size() > 0) {
5260
PADDLE_ENFORCE_EQ(sections.size(), outs_number,

0 commit comments

Comments
 (0)