Skip to content

Commit 03ccb9a

Browse files
committed
Optimize the stack operator
1 parent 7a64d48 commit 03ccb9a

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

paddle/fluid/operators/stack_op.h

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -147,16 +147,23 @@ class StackKernel : public framework::OpKernel<T> {
147147
auto &dim = x[0]->dims();
148148
for (auto i = 0; i < axis; ++i) pre *= dim[i];
149149
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
150-
int total_num = pre * n * post;
151150

152-
auto &dev_ctx = ctx.template device_context<DeviceContext>();
153151
#ifdef __NVCC__
154152
thrust::device_vector<const T *> device_x_vec(x_datas);
155153
auto x_data_arr = device_x_vec.data().get();
156154
#else
157155
auto x_data_arr = x_datas.data();
158156
#endif
159-
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
157+
size_t x_offset = 0;
158+
size_t y_offset = 0;
159+
for (int i = 0; i < pre; i++) {
160+
for (int j = 0; j < n; j++) {
161+
std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
162+
post * sizeof(T));
163+
y_offset += post;
164+
}
165+
x_offset += post;
166+
}
160167
#ifdef __NVCC__
161168
// Wait() must be called because device_x_vec may be destructed before
162169
// kernel ends

0 commit comments

Comments
 (0)