@@ -162,13 +162,16 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
   LogAddExp op_;

   __device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
-      : max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
+      : max_so_far_(identity),
+        scaled_sum_(static_cast<T>(0.0)),
+        compensation_(static_cast<T>(0.0)),
+        op_(op) {}

   __device__ T operator()(T block_aggregate) {
     if (scaled_sum_ == 0.0) {
       max_so_far_ = block_aggregate;
-      scaled_sum_ = 1.0;
-      compensation_ = 0.0;
+      scaled_sum_ = static_cast<T>(1.0);
+      compensation_ = static_cast<T>(0.0);
       return std::numeric_limits<T>::lowest();
     }

@@ -255,6 +258,74 @@ __global__ void BlockScanKernel(T* d_out,
   }
 }

+template <typename Context, typename T>
+void ThrustCumsumKernel(const Context& dev_ctx,
+                        const T* in_data,
+                        T* out_data,
+                        int64_t size,
+                        bool reverse,
+                        bool exclusive) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+#ifdef __HIPCC__
+  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
+#else
+  phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
+                                                             dev_ctx.stream());
+  const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
+#endif
+
+  if constexpr (std::is_same_v<T, MT>) {
+    if (reverse) {
+      thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
+          thrust::device_pointer_cast(in_data) + size);
+      thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
+          thrust::device_pointer_cast(out_data) + size);
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+      } else {
+        thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+      }
+    }
+  } else {
+    thrust::device_vector<MT> tmp_in(size);
+    thrust::device_vector<MT> tmp_out(size);
+    thrust::copy(policy, in_data, in_data + size, tmp_in.begin());
+
+    auto tmp_in_begin = tmp_in.begin();
+    auto tmp_in_end = tmp_in.end();
+    auto tmp_out_begin = tmp_out.begin();
+
+    if (reverse) {
+      auto reversed_in = tmp_in.rbegin();
+      auto reversed_out = tmp_out.rbegin();
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      } else {
+        thrust::inclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      }
+    }
+
+    thrust::copy(policy, tmp_out.begin(), tmp_out.end(), out_data);
+  }
+}
+
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
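Aside (not part of the diff): ThrustCumsumKernel handles reverse=true by scanning a reversed view of the input into a reversed view of the output, rather than materializing a flipped copy. Below is a minimal standalone sketch of that pattern using plain Thrust containers and sample data; nothing in it comes from the Paddle sources.

#include <cstdio>
#include <vector>

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>

int main() {
  std::vector<float> h_in = {1.f, 2.f, 3.f, 4.f};
  thrust::device_vector<float> in(h_in.begin(), h_in.end());
  thrust::device_vector<float> out(in.size());

  // Forward inclusive scan: out = {1, 3, 6, 10}.
  thrust::inclusive_scan(in.begin(), in.end(), out.begin());

  // Reverse inclusive scan without copying: scan the reversed view of the
  // input into the reversed view of the output, giving out = {10, 9, 7, 4}.
  thrust::inclusive_scan(in.rbegin(), in.rend(), out.rbegin());

  // Reverse exclusive scan: each element becomes the sum of everything
  // strictly after it, giving out = {9, 7, 4, 0}.
  thrust::exclusive_scan(in.rbegin(), in.rend(), out.rbegin());

  thrust::host_vector<float> h_out = out;
  for (float v : h_out) std::printf("%g ", v);  // prints: 9 7 4 0
  std::printf("\n");
  return 0;
}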
@@ -295,6 +366,15 @@ void ScanKernel(const Context& dev_ctx,

   const T* in_data = x.data<T>();

+  // Use thrust for parallel acceleration when the input size is equal to the
+  // length of the 'axis' dimension (i.e., it's a 1D scan).
+  int64_t size = x.numel();
+  if (std::is_same_v<Op, cub::Sum> && size == out_dims[axis]) {
+    ThrustCumsumKernel<Context, T>(
+        dev_ctx, in_data, out_data, size, reverse, exclusive);
+    return;
+  }
+
   size_t height = 1;
   size_t width = 1;
   for (size_t i = 0; i <= axis; i++) {
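Aside (not part of the diff): when T differs from its compute type MT (for example a reduced-precision type accumulated in float), the new kernel stages the data through device_vector<MT>, scans in the wider type, and narrows the result back. The sketch below shows that staging shape with float widened to double purely for illustration; Paddle's MPTypeTrait and its real type pairs are not reproduced here.

#include <cstdio>
#include <vector>

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>

int main() {
  std::vector<float> h_in = {0.1f, 0.2f, 0.3f, 0.4f};
  thrust::device_vector<float> in(h_in.begin(), h_in.end());
  thrust::device_vector<float> out(in.size());

  // Widen, scan in the wider type, then narrow back -- the same structure
  // as the T != MT branch of ThrustCumsumKernel.
  thrust::device_vector<double> tmp_in(in.size());
  thrust::device_vector<double> tmp_out(in.size());
  thrust::copy(in.begin(), in.end(), tmp_in.begin());          // widen
  thrust::inclusive_scan(tmp_in.begin(), tmp_in.end(), tmp_out.begin());
  thrust::copy(tmp_out.begin(), tmp_out.end(), out.begin());   // narrow

  thrust::host_vector<float> h_out = out;
  for (float v : h_out) std::printf("%g ", v);  // approx: 0.1 0.3 0.6 1
  std::printf("\n");
  return 0;
}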