Skip to content

Commit c1488b1

Browse files
author
Yibing Liu
authored
Merge pull request #12940 from sneaxiy/stack_op
Speedup stack_op
2 parents dbd7896 + 3b38e5a commit c1488b1

File tree

1 file changed

+14
-42
lines changed

1 file changed

+14
-42
lines changed

paddle/fluid/operators/stack_op.h

Lines changed: 14 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -150,30 +150,17 @@ class StackKernel : public framework::OpKernel<T> {
150150
int total_num = pre * n * post;
151151

152152
auto &dev_ctx = ctx.template device_context<DeviceContext>();
153-
constexpr auto kMaxThreshold = 16;
154-
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
155-
n > kMaxThreshold) {
156153
#ifdef __NVCC__
157-
VLOG(10) << "Stack more than " << kMaxThreshold
158-
<< " tensors on GPU may be slow.";
159-
thrust::device_vector<const T *> device_x_vec(x_datas);
160-
auto x_data_arr = device_x_vec.data().get();
154+
thrust::device_vector<const T *> device_x_vec(x_datas);
155+
auto x_data_arr = device_x_vec.data().get();
161156
#else
162-
auto x_data_arr = x_datas.data();
157+
auto x_data_arr = x_datas.data();
163158
#endif
164-
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
159+
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
165160
#ifdef __NVCC__
166-
// Wait() must be called because device_x_vec may be destructed before
167-
// kernel ends
168-
dev_ctx.Wait();
169-
#endif
170-
}
171-
#ifdef __NVCC__
172-
else { // NOLINT
173-
framework::Array<const T *, kMaxThreshold> x_data_arr;
174-
for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i];
175-
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
176-
}
161+
// Wait() must be called because device_x_vec may be destructed before
162+
// kernel ends
163+
dev_ctx.Wait();
177164
#endif
178165
}
179166
};
@@ -244,32 +231,17 @@ class StackGradKernel : public framework::OpKernel<T> {
244231
int post = total_num / (n * pre);
245232

246233
auto &dev_ctx = ctx.template device_context<DeviceContext>();
247-
constexpr auto kMaxThreshold = 16;
248-
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
249-
n > kMaxThreshold) {
250234
#ifdef __NVCC__
251-
VLOG(10) << "Stack more than " << kMaxThreshold
252-
<< " tensors on GPU may be slow.";
253-
thrust::device_vector<T *> device_dx_vec(dx_datas);
254-
auto dx_data_arr = device_dx_vec.data().get();
235+
thrust::device_vector<T *> device_dx_vec(dx_datas);
236+
auto dx_data_arr = device_dx_vec.data().get();
255237
#else
256-
auto dx_data_arr = dx_datas.data();
238+
auto dx_data_arr = dx_datas.data();
257239
#endif
258-
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
259-
post);
240+
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
260241
#ifdef __NVCC__
261-
// Wait() must be called because device_dx_vec may be destructed before
262-
// kernel ends
263-
dev_ctx.Wait();
264-
#endif
265-
}
266-
#ifdef __NVCC__
267-
else { // NOLINT
268-
framework::Array<T *, kMaxThreshold> dx_data_arr;
269-
for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i];
270-
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
271-
post);
272-
}
242+
// Wait() must be called because device_dx_vec may be destructed before
243+
// kernel ends
244+
dev_ctx.Wait();
273245
#endif
274246
}
275247
};

0 commit comments

Comments
 (0)