Skip to content

Commit a906a36

Browse files
committed
Add the macro for NVCC (test=develop)
1 parent d91740a commit a906a36

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

paddle/fluid/operators/stack_op.h

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -149,11 +149,20 @@ class StackKernel : public framework::OpKernel<T> {
149149
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
150150

151151
#ifdef __NVCC__
152+
int total_num = pre * n * post;
153+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
154+
152155
thrust::device_vector<const T *> device_x_vec(x_datas);
153156
auto x_data_arr = device_x_vec.data().get();
157+
158+
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
159+
160+
// Wait() must be called because device_x_vec may be destructed before
161+
// kernel ends
162+
dev_ctx.Wait();
154163
#else
155164
auto x_data_arr = x_datas.data();
156-
#endif
165+
157166
size_t x_offset = 0;
158167
size_t y_offset = 0;
159168
for (int i = 0; i < pre; i++) {
@@ -164,10 +173,6 @@ class StackKernel : public framework::OpKernel<T> {
164173
}
165174
x_offset += post;
166175
}
167-
#ifdef __NVCC__
168-
// Wait() must be called because device_x_vec may be destructed before
169-
// kernel ends
170-
dev_ctx.Wait();
171176
#endif
172177
}
173178
};

0 commit comments

Comments
 (0)