Skip to content

Commit 1d3e9bd

Browse files
authored
Merge pull request #14488 from yihuaxu/develop_7a64d48f5_stack_opt
Optimize the stack operator
2 parents 8bc1c5d + a906a36 commit 1d3e9bd

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

paddle/fluid/operators/stack_op.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,32 @@ class StackKernel : public framework::OpKernel<T> {
147147
auto &dim = x[0]->dims();
148148
for (auto i = 0; i < axis; ++i) pre *= dim[i];
149149
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
150-
int total_num = pre * n * post;
151150

152-
auto &dev_ctx = ctx.template device_context<DeviceContext>();
153151
#ifdef __NVCC__
152+
int total_num = pre * n * post;
153+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
154+
154155
thrust::device_vector<const T *> device_x_vec(x_datas);
155156
auto x_data_arr = device_x_vec.data().get();
156-
#else
157-
auto x_data_arr = x_datas.data();
158-
#endif
157+
159158
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
160-
#ifdef __NVCC__
159+
161160
// Wait() must be called because device_x_vec may be destructed before
162161
// kernel ends
163162
dev_ctx.Wait();
163+
#else
164+
auto x_data_arr = x_datas.data();
165+
166+
size_t x_offset = 0;
167+
size_t y_offset = 0;
168+
for (int i = 0; i < pre; i++) {
169+
for (int j = 0; j < n; j++) {
170+
std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
171+
post * sizeof(T));
172+
y_offset += post;
173+
}
174+
x_offset += post;
175+
}
164176
#endif
165177
}
166178
};

0 commit comments

Comments
 (0)