Skip to content

Commit 6bc515d

Browse files
authored
Use cooperative-groups for warp-parallel kernels in strings functions (rapidsai#18959)
Replaces some warp-parallel logic in strings internal functions to use cooperative groups instead. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Muhammad Haseeb (https://github.com/mhaseeb123) - Yunsong Wang (https://github.com/PointKernel) URL: rapidsai#18959
1 parent 74e6724 commit 6bc515d

File tree

4 files changed

+45
-49
lines changed

4 files changed

+45
-49
lines changed

cpp/src/strings/case.cu

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
#include <rmm/cuda_stream_view.hpp>
3636
#include <rmm/exec_policy.hpp>
3737

38-
#include <cub/cub.cuh>
38+
#include <cooperative_groups.h>
39+
#include <cooperative_groups/reduce.h>
3940
#include <cuda/atomic>
4041
#include <cuda/functional>
4142
#include <thrust/binary_search.h>
@@ -285,17 +286,13 @@ CUDF_KERNEL void count_bytes_kernel(convert_char_fn converter,
285286
column_device_view d_strings,
286287
size_type* d_sizes)
287288
{
288-
auto idx = cudf::detail::grid_1d::global_thread_id();
289-
if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }
289+
namespace cg = cooperative_groups;
290+
auto const warp = cg::tiled_partition<cudf::detail::warp_size>(cg::this_thread_block());
291+
auto const lane_idx = warp.thread_rank();
290292

291-
auto const str_idx = idx / cudf::detail::warp_size;
292-
auto const lane_idx = idx % cudf::detail::warp_size;
293+
auto const str_idx = warp.meta_group_rank();
294+
if (str_idx >= d_strings.size() or d_strings.is_null(str_idx)) { return; }
293295

294-
// initialize the output for the atomicAdd
295-
if (lane_idx == 0) { d_sizes[str_idx] = 0; }
296-
__syncwarp();
297-
298-
if (d_strings.is_null(str_idx)) { return; }
299296
auto const d_str = d_strings.element<string_view>(str_idx);
300297
auto const str_ptr = d_str.data();
301298

@@ -311,11 +308,9 @@ CUDF_KERNEL void count_bytes_kernel(convert_char_fn converter,
311308
size += converter.process_character(u8);
312309
}
313310
}
314-
// this is slightly faster than using the cub::warp_reduce
315-
if (size > 0) {
316-
cuda::atomic_ref<size_type, cuda::thread_scope_block> ref{*(d_sizes + str_idx)};
317-
ref.fetch_add(size, cuda::std::memory_order_relaxed);
318-
}
311+
312+
auto out_size = cg::reduce(warp, size, cg::plus<size_type>());
313+
if (lane_idx == 0) { d_sizes[str_idx] = out_size; }
319314
}
320315

321316
/**

cpp/src/strings/convert/convert_urls.cu

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -33,6 +33,9 @@
3333
#include <rmm/cuda_stream_view.hpp>
3434
#include <rmm/device_uvector.hpp>
3535

36+
#include <cooperative_groups.h>
37+
#include <cooperative_groups/reduce.h>
38+
#include <cooperative_groups/scan.h>
3639
#include <cub/cub.cuh>
3740

3841
namespace cudf {
@@ -202,11 +205,14 @@ CUDF_KERNEL void url_decode_char_counter(column_device_view const in_strings,
202205
__shared__ char temporary_buffer[num_warps_per_threadblock][char_block_size + halo_size];
203206
__shared__ typename cub::WarpReduce<int8_t>::TempStorage cub_storage[num_warps_per_threadblock];
204207

205-
auto const global_thread_id =
206-
cudf::detail::grid_1d::global_thread_id<num_warps_per_threadblock * cudf::detail::warp_size>();
207-
auto const global_warp_id = static_cast<size_type>(global_thread_id / cudf::detail::warp_size);
208-
auto const local_warp_id = static_cast<size_type>(threadIdx.x / cudf::detail::warp_size);
209-
auto const warp_lane = static_cast<size_type>(threadIdx.x % cudf::detail::warp_size);
208+
namespace cg = cooperative_groups;
209+
auto const block = cg::this_thread_block();
210+
auto const warp = cg::tiled_partition<cudf::detail::warp_size>(block);
211+
212+
auto const global_thread_id = cudf::detail::grid_1d::global_thread_id();
213+
auto const global_warp_id = static_cast<size_type>(global_thread_id / cudf::detail::warp_size);
214+
auto const local_warp_id = static_cast<size_type>(warp.meta_group_rank());
215+
auto const warp_lane = static_cast<size_type>(warp.thread_rank());
210216
auto const nwarps =
211217
static_cast<size_type>(cudf::detail::grid_1d::grid_stride() / cudf::detail::warp_size);
212218
char* in_chars_shared = temporary_buffer[local_warp_id];
@@ -241,7 +247,7 @@ CUDF_KERNEL void url_decode_char_counter(column_device_view const in_strings,
241247
in_chars_shared[char_idx] = in_idx < string_length ? in_chars[in_idx] : 0;
242248
}
243249

244-
__syncwarp();
250+
warp.sync();
245251

246252
// `char_idx_start` represents the start character index of the current warp.
247253
for (size_type char_idx_start = 0; char_idx_start < string_length_block;
@@ -258,7 +264,7 @@ CUDF_KERNEL void url_decode_char_counter(column_device_view const in_strings,
258264

259265
if (warp_lane == 0) { escape_char_count += total_escape_char; }
260266

261-
__syncwarp();
267+
warp.sync();
262268
}
263269
}
264270
// URL decoding replaces 3 bytes with 1 for each escape character.
@@ -289,11 +295,14 @@ CUDF_KERNEL void url_decode_char_replacer(column_device_view const in_strings,
289295
__shared__ typename cub::WarpScan<int8_t>::TempStorage cub_storage[num_warps_per_threadblock];
290296
__shared__ size_type out_idx[num_warps_per_threadblock];
291297

292-
auto const global_thread_id =
293-
cudf::detail::grid_1d::global_thread_id<num_warps_per_threadblock * cudf::detail::warp_size>();
294-
auto const global_warp_id = static_cast<size_type>(global_thread_id / cudf::detail::warp_size);
295-
auto const local_warp_id = static_cast<size_type>(threadIdx.x / cudf::detail::warp_size);
296-
auto const warp_lane = static_cast<size_type>(threadIdx.x % cudf::detail::warp_size);
298+
namespace cg = cooperative_groups;
299+
auto const block = cg::this_thread_block();
300+
auto const warp = cg::tiled_partition<cudf::detail::warp_size>(block);
301+
302+
auto const global_thread_id = cudf::detail::grid_1d::global_thread_id();
303+
auto const global_warp_id = static_cast<size_type>(global_thread_id / cudf::detail::warp_size);
304+
auto const local_warp_id = static_cast<size_type>(warp.meta_group_rank());
305+
auto const warp_lane = static_cast<size_type>(warp.thread_rank());
297306
auto const nwarps =
298307
static_cast<size_type>(cudf::detail::grid_1d::grid_stride() / cudf::detail::warp_size);
299308
char* in_chars_shared = temporary_buffer[local_warp_id];
@@ -326,7 +335,7 @@ CUDF_KERNEL void url_decode_char_replacer(column_device_view const in_strings,
326335
in_chars_shared[char_idx] = in_idx >= 0 && in_idx < string_length ? in_chars[in_idx] : 0;
327336
}
328337

329-
__syncwarp();
338+
warp.sync();
330339

331340
// `char_idx_start` represents the start character index of the current warp.
332341
for (size_type char_idx_start = 0; char_idx_start < string_length_block;
@@ -364,7 +373,7 @@ CUDF_KERNEL void url_decode_char_replacer(column_device_view const in_strings,
364373
out_idx[local_warp_id] += (out_offset + out_size);
365374
}
366375

367-
__syncwarp();
376+
warp.sync();
368377
}
369378
}
370379
}

cpp/src/strings/search/find.cu

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@
3333
#include <rmm/exec_policy.hpp>
3434

3535
#include <cooperative_groups.h>
36+
#include <cooperative_groups/reduce.h>
3637
#include <cuda/atomic>
3738
#include <cuda/std/utility>
3839
#include <thrust/binary_search.h>
@@ -121,17 +122,12 @@ CUDF_KERNEL void finder_warp_parallel_fn(column_device_view const d_strings,
121122
size_type const stop,
122123
size_type* d_results)
123124
{
124-
auto const idx = cudf::detail::grid_1d::global_thread_id();
125-
126-
auto const str_idx = idx / cudf::detail::warp_size;
127-
if (str_idx >= d_strings.size()) { return; }
128-
auto const lane_idx = idx % cudf::detail::warp_size;
129-
130-
if (d_strings.is_null(str_idx)) { return; }
125+
namespace cg = cooperative_groups;
126+
auto const warp = cg::tiled_partition<cudf::detail::warp_size>(cg::this_thread_block());
127+
auto const lane_idx = warp.thread_rank();
131128

132-
// initialize the output for the atomicMin/Max
133-
if (lane_idx == 0) { d_results[str_idx] = forward ? std::numeric_limits<size_type>::max() : -1; }
134-
__syncwarp();
129+
auto const str_idx = warp.meta_group_rank();
130+
if (str_idx >= d_strings.size() or d_strings.is_null(str_idx)) { return; }
135131

136132
auto const d_str = d_strings.element<string_view>(str_idx);
137133
auto const d_target = d_targets[str_idx];
@@ -158,16 +154,12 @@ CUDF_KERNEL void finder_warp_parallel_fn(column_device_view const d_strings,
158154
}
159155

160156
// find stores the minimum position while rfind stores the maximum position
161-
// note that this was slightly faster than using cub::WarpReduce
162-
cuda::atomic_ref<size_type, cuda::thread_scope_block> ref{*(d_results + str_idx)};
163-
forward ? ref.fetch_min(position, cuda::std::memory_order_relaxed)
164-
: ref.fetch_max(position, cuda::std::memory_order_relaxed);
165-
__syncwarp();
157+
auto const result = forward ? cg::reduce(warp, position, cg::less<size_type>())
158+
: cg::reduce(warp, position, cg::greater<size_type>());
166159

167160
if (lane_idx == 0) {
168161
// the final result needs to be fixed up convert max() to -1
169162
// and a byte position to a character position
170-
auto const result = d_results[str_idx];
171163
d_results[str_idx] =
172164
((result < std::numeric_limits<size_type>::max()) && (result >= begin))
173165
? start_char_pos + characters_in_string(d_str.data() + begin, result - begin)

cpp/src/strings/slice.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -129,7 +129,7 @@ CUDF_KERNEL void substring_from_kernel(column_device_view const d_strings,
129129
itr += cudf::detail::warp_size;
130130
}
131131

132-
__syncwarp();
132+
warp.sync();
133133

134134
if (warp.thread_rank() == 0) {
135135
if (start >= char_count) {

0 commit comments

Comments
 (0)