
Commit e0b24f6

q10 authored and facebook-github-bot committed
Migrate old CUB code to be compatible with CUB 3 (#4835)
Summary:
X-link: facebookresearch/FBGEMM#1863

- Migrate old CUB code to be compatible with CUB 3 (introduced in CUDA 13). CUB 3 removes the debug_synchronous flag from the radix sort API entirely; the flag had already been unused as far back as CUB 2.2.0 (part of CUDA 12.3):
  https://github.com/NVIDIA/cccl/blob/v2.2.0/cub/cub/device/device_radix_sort.cuh
  https://docs.nvidia.com/cuda/archive/12.3.0/cuda-toolkit-release-notes/index.html
  However, the flag appears to still be in use in the ROCm equivalent:
  https://github.com/ROCm/rocm-libraries/blob/main/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp
- Add ROCm compatibility with CUB Min and Max by using
  https://github.com/ROCm/rocm-libraries/blob/main/projects/rocthrust/thrust/functional.h

Pull Request resolved: #4835

Reviewed By: cthi

Differential Revision: D81960727

Pulled By: q10

fbshipit-source-id: 0a282bfd3cef3d78078df23d93a8c5ab85b64cb8
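
For illustration only, here is a minimal sketch of what dropping debug_synchronous means at a call site. This snippet is not taken from the diff; the function, buffer, and variable names are hypothetical.

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Hypothetical call site, for illustration only; names are not from this commit.
// Pre-CUB-3 code could pass a trailing debug_synchronous argument after the
// stream; with CUB 3 (CUDA 13) the argument list simply ends at the stream.
void sort_pairs_sketch(
    const int* keys_in, int* keys_out,
    const float* vals_in, float* vals_out,
    int num_items, cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // First call only computes the required temporary storage size.
  cub::DeviceRadixSort::SortPairs(
      d_temp_storage, temp_storage_bytes,
      keys_in, keys_out, vals_in, vals_out,
      num_items, 0, int(sizeof(int) * 8), stream);
  cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream);
  // Second call performs the sort; note there is no trailing `false` flag.
  cub::DeviceRadixSort::SortPairs(
      d_temp_storage, temp_storage_bytes,
      keys_in, keys_out, vals_in, vals_out,
      num_items, 0, int(sizeof(int) * 8), stream);
  cudaFreeAsync(d_temp_storage, stream);
}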
1 parent 08ae0af · commit e0b24f6

9 files changed: +75 −31 lines

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 4 additions & 3 deletions
@@ -38,6 +38,7 @@
 
 #include "fbgemm_gpu/utils/cuda_block_count.h"
 #include "fbgemm_gpu/utils/cuda_prelude.cuh"
+#include "fbgemm_gpu/utils/device_sort.cuh"
 #include "fbgemm_gpu/utils/stochastic_rounding.cuh"
 
 #if !( \
@@ -54,7 +55,6 @@
 #ifndef USE_ROCM
 #include <mma.h>
 #endif
-#include <cub/cub.cuh>
 
 #include <torch/torch.h>
 
@@ -1702,7 +1702,8 @@ __device__ float compute_max_block(
 
   __shared__ typename BlockReduce::TempStorage temp_storage[THREAD_Y];
 
-  float amax = BlockReduce(temp_storage[threadIdx.y]).Reduce(xabs, cub::Max());
+  float amax =
+      BlockReduce(temp_storage[threadIdx.y]).Reduce(xabs, Max<float>());
 
   __shared__ float amax_smem[THREAD_Y];
   if (threadIdx.x == 0)
@@ -1724,7 +1725,7 @@ __device__ float compute_max_warp(
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREAD_Y];
 
-  float amax = WarpReduce(temp_storage[threadIdx.y]).Reduce(xabs, cub::Max());
+  float amax = WarpReduce(temp_storage[threadIdx.y]).Reduce(xabs, Max<float>());
   amax = __shfl_sync(0xffffffff, amax, 0);
   return amax;
 }
fbgemm_gpu/include/fbgemm_gpu/utils/device_sort.cuh

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cub/cub.cuh>
+
+#ifdef USE_ROCM
+#include <thrust/functional.h>
+#else
+#include <cuda/functional>
+#endif
+
+// clang-format off
+#include "fbgemm_gpu/utils/cub_namespace_prefix.cuh"
+#include <cub/device/device_radix_sort.cuh>
+#include <cub/device/device_scan.cuh>
+#include "fbgemm_gpu/utils/cub_namespace_postfix.cuh"
+// clang-format on
+
+namespace fbgemm_gpu {
+
+#ifdef USE_ROCM
+template <typename T>
+using Max = thrust::maximum<T>;
+#else
+#if CUDA_VERSION >= 13000
+template <typename T>
+using Max = cuda::maximum<T>;
+#else
+template <typename T>
+using Max = cub::Max;
+#endif
+#endif
+
+#ifdef USE_ROCM
+template <typename T>
+using Min = thrust::minimum<T>;
+#else
+#if CUDA_VERSION >= 13000
+template <typename T>
+using Min = cuda::minimum<T>;
+#else
+template <typename T>
+using Min = cub::Min;
+#endif
+#endif
+
+} // namespace fbgemm_gpu
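
For reference, a minimal usage sketch of the new Max<T> alias in a CUB block reduction. This kernel is hypothetical and not part of the commit; kBlockSize and the buffer layout are illustrative.

#include <cub/cub.cuh>
#include "fbgemm_gpu/utils/device_sort.cuh"

constexpr int kBlockSize = 128;

// One block reduces kBlockSize contiguous elements to a single maximum.
__global__ void block_max_kernel(const float* in, float* out) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  const float v = in[blockIdx.x * kBlockSize + threadIdx.x];
  // fbgemm_gpu::Max<float> resolves to thrust::maximum<float> on ROCm,
  // cuda::maximum<float> on CUDA 13+, and cub::Max on older CUDA toolkits.
  const float block_max =
      BlockReduce(temp_storage).Reduce(v, fbgemm_gpu::Max<float>());
  if (threadIdx.x == 0) {
    out[blockIdx.x] = block_max;
  }
}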

fbgemm_gpu/src/jagged_tensor_ops/common.cuh

Lines changed: 1 addition & 6 deletions
@@ -16,13 +16,8 @@
 #include <torch/csrc/autograd/custom_function.h>
 #include <torch/library.h>
 #include <ATen/cuda/Atomic.cuh>
-#include <cub/cub.cuh>
 
-// clang-format off
-#include "fbgemm_gpu/utils/cub_namespace_prefix.cuh"
-#include <cub/device/device_scan.cuh>
-#include "fbgemm_gpu/utils/cub_namespace_postfix.cuh"
-// clang-format on
+#include "fbgemm_gpu/utils/device_sort.cuh"
 
 #include "common.h"
 #include "fbgemm_gpu/sparse_ops.h"

fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_softmax_kernel(
 
       // Collectively compute the block-wide max reduction
      scalar_t block_max_value =
-          BlockReduceT(temp_storage).Reduce(thread_val, cub::Max());
+          BlockReduceT(temp_storage).Reduce(thread_val, Max<index_t>());
      __syncthreads();
 
      if (tid == 0) {

fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu

Lines changed: 6 additions & 3 deletions
@@ -121,8 +121,10 @@ __global__ __launch_bounds__(kMaxThreads) void unique_indices_length_kernel(
     t_min = (value < t_min) ? value : t_min;
   }
 
-  index_t block_max = BlockReduce(temp_storage_max).Reduce(t_max, cub::Max());
-  index_t block_min = BlockReduce(temp_storage_min).Reduce(t_min, cub::Min());
+  index_t block_max =
+      BlockReduce(temp_storage_max).Reduce(t_max, Max<index_t>());
+  index_t block_min =
+      BlockReduce(temp_storage_min).Reduce(t_min, Min<index_t>());
   if (tid == 0) {
     block_results[0] = block_max;
     block_results[1] = block_min;
@@ -240,7 +242,8 @@ __global__ __launch_bounds__(kMaxThreads) void compute_hash_size_kernel(
     t_max = (value > t_max) ? value : t_max;
   }
 
-  index_t block_max = BlockReduce(temp_storage_max).Reduce(t_max, cub::Max());
+  index_t block_max =
+      BlockReduce(temp_storage_max).Reduce(t_max, Max<index_t>());
   if (tid == 0) {
     hash_size[bid] = block_max + 1;
   }

fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu

Lines changed: 2 additions & 4 deletions
@@ -152,8 +152,7 @@ std::pair<Tensor, Tensor> lfu_cache_find_uncached_cuda(
           N,
           0,
           int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits,
-          at::cuda::getCurrentCUDAStream(),
-          false));
+          at::cuda::getCurrentCUDAStream()));
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           unique_indices.options().dtype(at::kByte));
@@ -167,8 +166,7 @@ std::pair<Tensor, Tensor> lfu_cache_find_uncached_cuda(
           N,
           0,
           int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits,
-          at::cuda::getCurrentCUDAStream(),
-          false));
+          at::cuda::getCurrentCUDAStream()));
     });
   return {sorted_cache_sets, cache_set_sorted_unique_indices};
 }

fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu

Lines changed: 4 additions & 8 deletions
@@ -232,8 +232,7 @@ get_unique_indices_cuda_impl(
       N, \
       0, \
       int(log2(float(max_indices + 1)) + 1), \
-      at::cuda::getCurrentCUDAStream(), \
-      false))
+      at::cuda::getCurrentCUDAStream()))
 
 #define INVOKE_CUB_SORT_KEYS(TEMP_STORAGE_PTR) \
   AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( \
@@ -244,8 +243,7 @@ get_unique_indices_cuda_impl(
       N, \
       0, \
       int(log2(float(max_indices + 1)) + 1), \
-      at::cuda::getCurrentCUDAStream(), \
-      false))
+      at::cuda::getCurrentCUDAStream()))
 
 #define INVOKE_CUB_ENCODE(TEMP_STORAGE_PTR) \
   AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( \
@@ -256,8 +254,7 @@ get_unique_indices_cuda_impl(
       unique_indices_count->data_ptr<int32_t>(), \
       unique_indices_length.data_ptr<int32_t>(), \
       N, \
-      at::cuda::getCurrentCUDAStream(), \
-      false))
+      at::cuda::getCurrentCUDAStream()))
 
 #define INVOKE_CUB_UNIQUE(TEMP_STORAGE_PTR) \
   AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( \
@@ -267,8 +264,7 @@ get_unique_indices_cuda_impl(
       unique_indices.data_ptr<index_t>(), \
       unique_indices_length.data_ptr<int32_t>(), \
       N, \
-      at::cuda::getCurrentCUDAStream(), \
-      false))
+      at::cuda::getCurrentCUDAStream()))
 
   AT_DISPATCH_INDEX_TYPES(
       linear_indices.scalar_type(), "get_unique_indices_cuda", [&] {

fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu

Lines changed: 1 addition & 2 deletions
@@ -206,8 +206,7 @@ lru_cache_find_uncached_cuda(
       N, \
       0, \
       int(log2(float(lxu_cache_state.size(0) + 1)) + 1), \
-      at::cuda::getCurrentCUDAStream(), \
-      false))
+      at::cuda::getCurrentCUDAStream()))
 
   AT_DISPATCH_INDEX_TYPES(
       unique_indices.scalar_type(), "lru_cache_find_uncached_cuda", [&] {

fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu

Lines changed: 2 additions & 4 deletions
@@ -313,8 +313,7 @@ transpose_embedding_input(
         linear_indices.numel(),
         0,
         total_hash_size_bits,
-        at::cuda::getCurrentCUDAStream(),
-        false));
+        at::cuda::getCurrentCUDAStream()));
     auto temp_storage = at::empty(
         {static_cast<int64_t>(temp_storage_bytes)},
         indices.options().dtype(at::kByte));
@@ -329,8 +328,7 @@ transpose_embedding_input(
         linear_indices.numel(),
         0,
         total_hash_size_bits,
-        at::cuda::getCurrentCUDAStream(),
-        false));
+        at::cuda::getCurrentCUDAStream()));
 #else
     using config = rocprim::radix_sort_config<
         rocprim::default_config,
