
Commit f23b306

CUDA: Add top-k implementation

1 parent: ec047e1
8 files changed: +189 additions, -13 deletions

8 files changed

+189
-13
lines changed

cmake/CPM.cmake

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: MIT
+#
+# SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors
+
+# TODO: Remove this file once CCCL 3.2 is released & bundled with the CUDA Toolkit
+set(CPM_DOWNLOAD_VERSION 0.42.0)
+set(CPM_HASH_SUM "2020b4fc42dba44817983e06342e682ecfc3d2f484a581f11cc5731fbe4dce8a")
+
+if(CPM_SOURCE_CACHE)
+    set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
+elseif(DEFINED ENV{CPM_SOURCE_CACHE})
+    set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
+else()
+    set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
+endif()
+
+# Expand relative path. This is important if the provided path contains a tilde (~)
+get_filename_component(CPM_DOWNLOAD_LOCATION ${CPM_DOWNLOAD_LOCATION} ABSOLUTE)
+
+file(DOWNLOAD
+    https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake
+    ${CPM_DOWNLOAD_LOCATION} EXPECTED_HASH SHA256=${CPM_HASH_SUM}
+)
+
+include(${CPM_DOWNLOAD_LOCATION})

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
@@ -2,6 +2,17 @@ cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
 
 find_package(CUDAToolkit)
 
+# Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+if (GGML_CUDA_CUB_3DOT2)
+    include(../../../cmake/CPM.cmake)
+    # This will automatically clone CCCL from GitHub and make the exported cmake targets available
+    CPMAddPackage(
+        NAME CCCL
+        GITHUB_REPOSITORY nvidia/cccl
+        GIT_TAG v3.2.0-rc0 # pin to the v3.2.0 release candidate tag
+    )
+endif()
+
 if (CUDAToolkit_FOUND)
     message(STATUS "CUDA Toolkit found")
 
@@ -102,13 +113,19 @@ if (CUDAToolkit_FOUND)
             # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
             target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
         else ()
+            if (GGML_CUDA_CUB_3DOT2)
+                target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+            endif()
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
                 target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
             else()
                 target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
             endif()
         endif()
     else()
+        if (GGML_CUDA_CUB_3DOT2)
+            target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+        endif()
         target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
     endif()
 
ggml/src/ggml-cuda/argsort.cu

Lines changed: 13 additions & 13 deletions
@@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
 }
 
 #ifdef GGML_CUDA_USE_CUB
-static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
-                                     const float * x,
-                                     int * dst,
-                                     const int ncols,
-                                     const int nrows,
-                                     ggml_sort_order order,
-                                     cudaStream_t stream) {
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float * x,
+                              int * dst,
+                              const int ncols,
+                              const int nrows,
+                              ggml_sort_order order,
+                              cudaStream_t stream) {
     ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
@@ -162,12 +162,12 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda_bitonic(const float * x,
-                                         int * dst,
-                                         const int ncols,
-                                         const int nrows,
-                                         ggml_sort_order order,
-                                         cudaStream_t stream) {
+void argsort_f32_i32_cuda_bitonic(const float * x,
+                                  int * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
ggml/src/ggml-cuda/argsort.cuh

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,19 @@
 #include "common.cuh"
 
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float * x,
+                              int * dst,
+                              const int ncols,
+                              const int nrows,
+                              ggml_sort_order order,
+                              cudaStream_t stream);
+#endif // GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_bitonic(const float * x,
+                                  int * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t stream);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 6 additions & 0 deletions
@@ -44,6 +44,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/topk-moe.cuh"
@@ -2694,6 +2695,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SSM_SCAN:
             ggml_cuda_op_ssm_scan(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_cuda_op_top_k(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -4233,6 +4237,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CUMSUM:
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_TOP_K:
+            return true;
         case GGML_OP_ARGSORT:
 #ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;

ggml/src/ggml-cuda/top-k.cu

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+#include "argsort.cuh"
+#include "top-k.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#   include <cub/cub.cuh>
+#   if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
+#       define CUB_TOP_K_AVAILABLE
+using namespace cub;
+#   endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
+#endif // GGML_CUDA_USE_CUB
+
+#ifdef CUB_TOP_K_AVAILABLE
+static __global__ void init_indices(int * indices, const int ncols) {
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (col < ncols) {
+        indices[col] = col;
+    }
+}
+
+static void top_k_cub(ggml_cuda_pool & pool,
+                      const float * src,
+                      int * dst,
+                      const int ncols,
+                      const int k,
+                      cudaStream_t stream) {
+    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+                                                 cuda::execution::output_ordering::unsorted);
+    auto stream_env = cuda::stream_ref{ stream };
+    auto env = cuda::std::execution::env{ stream_env, requirements };
+
+    ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols);
+
+    int * temp_indices = temp_indices_alloc.get();
+    float * temp_keys = temp_keys_alloc.get();
+
+    static const int block_size = 256;
+    const dim3 grid_size((ncols + block_size - 1) / block_size, 1);
+    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols);
+
+    CUDA_CHECK(cudaMemcpyAsync(temp_keys, src, ncols * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+
+    size_t temp_storage_bytes = 0;
+    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void * d_temp_storage = temp_storage_alloc.get();
+
+    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);
+}
+
+#else
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
+#endif // CUB_TOP_K_AVAILABLE
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    int * dst_d = (int *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    // are these asserts truly necessary?
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+    const int64_t k = dst->ne[0];
+    ggml_cuda_pool & pool = ctx.pool();
+#ifdef CUB_TOP_K_AVAILABLE
+    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
+    // https://github.com/NVIDIA/cccl/issues/6391
+    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
+    for (int i = 0; i < nrows; i++) {
+        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
+    }
+#else
+    // Fall back to argsort + copy
+    const int ncols_pad = next_power_of_2(ncols);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+    int * tmp_dst = temp_dst_alloc.get();
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    }
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#endif // CUB_TOP_K_AVAILABLE
+}
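
For reference, the operator added above returns, per row, the indices of the k largest values. Below is a minimal CPU sketch of those semantics, with illustrative names only; it is not ggml's reference implementation. Note that the CUB path requests unsorted output ordering, so only the set of indices is comparable, while the argsort fallback yields them in descending value order.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Reference semantics only: for each row of a row-major [nrows x ncols] matrix,
// write the indices of the k largest values into dst (k indices per row).
static void top_k_reference(const float * src, int32_t * dst,
                            int64_t ncols, int64_t nrows, int64_t k) {
    std::vector<int32_t> idx(ncols);
    for (int64_t r = 0; r < nrows; ++r) {
        const float * row = src + r * ncols;
        std::iota(idx.begin(), idx.end(), 0);
        // Move the k largest elements (by value) to the front, in descending order.
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [row](int32_t a, int32_t b) { return row[a] > row[b]; });
        std::copy(idx.begin(), idx.begin() + k, dst + r * k);
    }
}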

ggml/src/ggml-cuda/top-k.cuh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 5 additions & 0 deletions
@@ -8035,6 +8035,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
     test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 40));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 1));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 400));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 40));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 1));
 
     return test_cases;
 }
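
For context, here is a host-side sketch of how the new op would be reached through the public ggml API, assuming ggml_top_k() lowers to GGML_OP_TOP_K on this branch. The graph setup is illustrative only and omits the ggml-backend plumbing needed to actually schedule it on the CUDA device.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 65000-wide f32 row, top-40 indices: mirrors one of the perf cases added above.
    struct ggml_tensor * x   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 65000);
    struct ggml_tensor * top = ggml_top_k(ctx, x, 40); // I32 tensor holding 40 indices per row

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, top);
    // ... fill x, dispatch gf on the CUDA backend via ggml-backend, then read top->data ...

    ggml_free(ctx);
    return 0;
}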
