Commit 9db2cad

[PHI] Fix paddle.cumsum calculation speed (#74442)
* fix ThrustCumsumKernel
* refine
* refine ThrustCumsumKernel
* fix
* update ThrustCumsumKernel
* fix logcumsumexp in ThrustCumsumKernel
1 parent 65e2105 commit 9db2cad

File tree

1 file changed: +83 -3 lines changed


paddle/phi/kernels/gpu/cum_kernel.cu

Lines changed: 83 additions & 3 deletions
@@ -162,13 +162,16 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
   LogAddExp op_;
 
   __device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
-      : max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
+      : max_so_far_(identity),
+        scaled_sum_(static_cast<T>(0.0)),
+        compensation_(static_cast<T>(0.0)),
+        op_(op) {}
 
   __device__ T operator()(T block_aggregate) {
     if (scaled_sum_ == 0.0) {
       max_so_far_ = block_aggregate;
-      scaled_sum_ = 1.0;
-      compensation_ = 0.0;
+      scaled_sum_ = static_cast<T>(1.0);
+      compensation_ = static_cast<T>(0.0);
       return std::numeric_limits<T>::lowest();
     }
 
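
This first hunk replaces bare double literals with static_cast<T>(...) in the prefix-callback initializers and assignments. A plausible motivation (an assumption, not stated in the commit message): when T is a reduced-precision type such as float16 or bfloat16, a bare 0.0 has to pick one of several converting constructors, which can be ambiguous or trigger narrowing diagnostics, while the explicit cast pins the conversion. A toy illustration with a hypothetical Fp16 type (not Paddle's phi::dtype::float16):

```cpp
// Toy illustration only: Fp16 is a hypothetical stand-in for a
// reduced-precision type. With converting constructors from both float
// and int, initializing a member from a bare double literal is ambiguous;
// an explicit cast selects the conversion and compiles cleanly.
struct Fp16 {
  unsigned short bits = 0;
  Fp16() = default;
  Fp16(float v) : bits(static_cast<unsigned short>(v)) {}  // toy encoding
  Fp16(int v) : bits(static_cast<unsigned short>(v)) {}    // toy encoding
};

struct Callback {
  Fp16 scaled_sum_;
  // Callback() : scaled_sum_(0.0) {}  // error: double -> float or int? ambiguous
  Callback() : scaled_sum_(static_cast<Fp16>(0.0f)) {}  // OK: float ctor selected
};

int main() {
  Callback cb;
  (void)cb;
  return 0;
}
```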
@@ -255,6 +258,74 @@ __global__ void BlockScanKernel(T* d_out,
   }
 }
 
+template <typename Context, typename T>
+void ThrustCumsumKernel(const Context& dev_ctx,
+                        const T* in_data,
+                        T* out_data,
+                        int64_t size,
+                        bool reverse,
+                        bool exclusive) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+#ifdef __HIPCC__
+  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
+#else
+  phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
+                                                             dev_ctx.stream());
+  const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
+#endif
+
+  if constexpr (std::is_same_v<T, MT>) {
+    if (reverse) {
+      thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
+          thrust::device_pointer_cast(in_data) + size);
+      thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
+          thrust::device_pointer_cast(out_data) + size);
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+      } else {
+        thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+      }
+    }
+  } else {
+    thrust::device_vector<MT> tmp_in(size);
+    thrust::device_vector<MT> tmp_out(size);
+    thrust::copy(policy, in_data, in_data + size, tmp_in.begin());
+
+    auto tmp_in_begin = tmp_in.begin();
+    auto tmp_in_end = tmp_in.end();
+    auto tmp_out_begin = tmp_out.begin();
+
+    if (reverse) {
+      auto reversed_in = tmp_in.rbegin();
+      auto reversed_out = tmp_out.rbegin();
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      } else {
+        thrust::inclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      }
+    }
+
+    thrust::copy(policy, tmp_out.begin(), tmp_out.end(), out_data);
+  }
+}
+
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
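
The core trick in the new kernel is handling `reverse` with thrust::reverse_iterator instead of a separate flip pass; the `else` branch additionally scans reduced-precision inputs in their higher-precision compute type (MPTypeTrait) via temporaries, trading two extra copies for accuracy. A standalone sketch of the reverse-scan idiom (plain Thrust only; no Paddle types, stream, or allocator plumbing, and names here are illustrative):

```cpp
// Minimal sketch of the reverse-scan idiom used by ThrustCumsumKernel:
// reverse views over the same device storage let one forward scan
// produce a right-to-left cumsum.
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> h_in = {1.f, 2.f, 3.f, 4.f};
  thrust::device_vector<float> in(h_in.begin(), h_in.end());
  thrust::device_vector<float> out(in.size());

  // Reverse views: rin[0] is in[3]; writes through rout land at the
  // mirrored positions of out.
  auto rin = in.rbegin();
  auto rout = out.rbegin();

  // A forward exclusive scan of the reversed input, written out reversed
  // again, is a right-to-left exclusive cumsum: out = {9, 7, 4, 0}.
  thrust::exclusive_scan(rin, rin + in.size(), rout);

  std::vector<float> h_out(out.size());
  thrust::copy(out.begin(), out.end(), h_out.begin());
  for (float v : h_out) std::printf("%g ", v);  // prints: 9 7 4 0
  std::printf("\n");
  return 0;
}
```
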
@@ -295,6 +366,15 @@ void ScanKernel(const Context& dev_ctx,
 
   const T* in_data = x.data<T>();
 
+  // Use thrust for parallel acceleration when the input size is equal to the
+  // length of the 'axis' dimension (i.e., it's a 1D scan).
+  int64_t size = x.numel();
+  if (std::is_same_v<Op, cub::Sum> && size == out_dims[axis]) {
+    ThrustCumsumKernel<Context, T>(
+        dev_ctx, in_data, out_data, size, reverse, exclusive);
+    return;
+  }
+
   size_t height = 1;
   size_t width = 1;
   for (size_t i = 0; i <= axis; i++) {
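
Note the two-part guard on the fast path: `Op` must be `cub::Sum` (so logcumsumexp, which scans with LogAddExp, stays on the BlockScanKernel path, matching the "fix logcumsumexp" commit bullet), and `numel == out_dims[axis]`, i.e. every other dimension has extent 1. A host-side sketch of that shape predicate (hypothetical helper, for illustration only; not part of the patch):

```cpp
// Hypothetical helper mirroring the dispatch condition above. It is true
// only when the scan is effectively one-dimensional: the product of all
// extents equals the extent along `axis`.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

static bool TakesThrustFastPath(const std::vector<int64_t>& dims, int axis) {
  const int64_t numel = std::accumulate(
      dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
  return numel == dims[axis];  // all other extents must be 1
}

int main() {
  std::printf("%d\n", TakesThrustFastPath({1024}, 0));     // 1: plain 1-D cumsum
  std::printf("%d\n", TakesThrustFastPath({1, 1024}, 1));  // 1: only axis 1 is non-trivial
  std::printf("%d\n", TakesThrustFastPath({4, 1024}, 1));  // 0: falls back to BlockScanKernel
  return 0;
}
```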
