Skip to content

Commit e7542a4

Browse files
authored
fix stack op grad nullptr (#31962) (#32005)
1 parent b934d0b commit e7542a4

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

paddle/fluid/operators/stack_op.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct StackGradFunctor {
3030
int i = idx / (n_ * post_);
3131
int which_x = idx / post_ - i * n_;
3232
int x_index = i * post_ + idx % post_;
33-
dx_[which_x][x_index] = dy_[idx];
33+
if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx];
3434
}
3535

3636
private:
@@ -95,19 +95,21 @@ class StackGradKernel : public framework::OpKernel<T> {
9595
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
9696
int axis = ctx.Attr<int>("axis");
9797
if (axis < 0) axis += dy->dims().size();
98-
9998
int n = dy->dims()[axis];
10099
std::vector<T *> dx_datas(n); // NOLINT
100+
101101
for (int i = 0; i < n; i++) {
102-
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
102+
if (dx[i] == nullptr) {
103+
dx_datas[i] = nullptr;
104+
} else {
105+
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
106+
}
103107
}
104108
auto dy_data = dy->data<T>();
105-
106109
int pre = 1;
107110
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
108111
int total_num = dy->numel();
109112
int post = total_num / (n * pre);
110-
111113
auto &dev_ctx = ctx.template device_context<DeviceContext>();
112114
auto dx_data_arr = dx_datas.data();
113115
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);

0 commit comments

Comments
 (0)