Skip to content

Commit 3063449

Browse files
authored
Merge pull request #16888 from jacquesqiao/cherry-pick-fix-split_byref_op_reshape
fix split_byref_op infer shape
2 parents 2644ef3 + 7e3f812 commit 3063449

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddle/fluid/operators/distributed_ops/split_byref_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@ class SplitByrefOp : public framework::OperatorWithKernel {
3131
auto in_dims = ctx->GetInputDim("X");
3232
auto outs_names = ctx->Outputs("Out");
3333
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
34-
std::vector<int> sections = static_cast<std::vector<int>>(
35-
ctx->Attrs().Get<std::vector<int>>("sections"));
34+
auto sections = ctx->Attrs().Get<std::vector<int>>("sections");
3635
const size_t outs_number = outs_names.size();
3736
std::vector<framework::DDim> outs_dims;
3837
outs_dims.reserve(outs_number);
3938

4039
if (num > 0) {
41-
int64_t in_axis_dim = in_dims[0];
40+
int64_t in_axis_dim = 0;
41+
if (ctx->IsRuntime()) {
42+
in_axis_dim = in_dims[0];
43+
}
4244
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
4345
"tensor split does not result"
4446
" in an equal division");

0 commit comments

Comments
 (0)