Skip to content

Commit 8417c61

Browse files
committed
Enable CUB
1 parent 6b9a524 commit 8417c61

File tree

7 files changed

+40
-8
lines changed

7 files changed

+40
-8
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ jobs:
558558
id: depends
559559
run: |
560560
sudo apt-get update
561-
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev
561+
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev hipcub-dev libcurl4-openssl-dev rocwmma-dev
562562
563563
- name: ccache
564564
uses: ggml-org/[email protected]

ggml/src/ggml-cuda/argsort.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
#include "argsort.cuh"
22

33
#ifdef GGML_CUDA_USE_CUB
4+
#if defined(GGML_USE_HIP)
5+
6+
7+
#include <hipcub/hipcub.hpp>
8+
using namespace hipcub;
9+
#else
410
# include <cub/cub.cuh>
511
using namespace cub;
12+
#endif // GGML_USE_HIP
13+
614
#endif // GGML_CUDA_USE_CUB
715

816
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@
9191
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
9292
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
9393

94-
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
94+
#if defined(GGML_USE_HIP) || (!defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070)
9595
# define GGML_CUDA_USE_CUB
96-
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
96+
#endif // defined(GGML_USE_HIP) || (!defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070)
9797

9898
#ifdef __CUDA_ARCH_LIST__
9999
constexpr bool ggml_cuda_has_arch_impl(int) {

ggml/src/ggml-cuda/mean.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@
22
#include "reduce_rows.cuh"
33

44
#ifdef GGML_CUDA_USE_CUB
5+
6+
#if defined(GGML_USE_HIP)
7+
#include <hipcub/hipcub.hpp>
8+
using namespace hipcub;
9+
#else
510
#include <cub/cub.cuh>
611
using namespace cub;
12+
#endif // GGML_USE_HIP
13+
714
#endif // GGML_CUDA_USE_CUB
815

916
template <typename T> __global__ void divide_by_count(T * result, size_t count) {

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1-
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
1+
#if defined(GGML_USE_HIP) || (!defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070)
22
#define USE_CUB
3-
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
3+
#endif // defined(GGML_USE_HIP) || (!defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070)
44

55
#ifdef USE_CUB
6+
7+
#if defined(GGML_USE_HIP)
8+
#include <hipcub/hipcub.hpp>
9+
using namespace hipcub;
10+
#else
611
#include <cub/cub.cuh>
712
using namespace cub;
13+
#endif // GGML_USE_HIP
14+
815
#endif // USE_CUB
916

1017
#include "ssm-scan.cuh"
@@ -48,8 +55,8 @@ __global__ void __launch_bounds__(splitD, 1)
4855
__shared__ float smemC[N];
4956

5057
#ifdef USE_CUB
51-
using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
52-
using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
58+
using BlockLoad = BlockLoad<float, splitD, N, BLOCK_LOAD_WARP_TRANSPOSE>;
59+
using BlockStore = BlockStore<float, splitD, N, BLOCK_STORE_WARP_TRANSPOSE>;
5360

5461
union CubTempStorage {
5562
typename BlockLoad::TempStorage load_temp;

ggml/src/ggml-cuda/sum.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,16 @@
22
#include "sumrows.cuh"
33

44
#ifdef GGML_CUDA_USE_CUB
5+
6+
#if defined(GGML_USE_HIP)
7+
#include <hipcub/hipcub.hpp>
8+
using namespace hipcub;
9+
#else
510
#include <cub/cub.cuh>
611
using namespace cub;
12+
13+
#endif // GGML_USE_HIP
14+
715
#endif // GGML_CUDA_USE_CUB
816

917
#include <cstdint>
@@ -16,7 +24,6 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
1624
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
1725
#else
1826
// Use (inefficient) sum_rows implementation as a fallback.
19-
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
2027
sum_rows_f32_cuda(x, dst, ne, 1, stream);
2128
GGML_UNUSED(pool);
2229
#endif // GGML_CUDA_USE_CUB

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@
106106
#define cudaStreamPerThread hipStreamPerThread
107107
#define cudaStreamSynchronize hipStreamSynchronize
108108
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
109+
#define cudaStreamIsCapturing hipStreamIsCapturing
110+
#define cudaStreamCaptureStatus hipStreamCaptureStatus
111+
#define cudaStreamCaptureStatusNone hipStreamCaptureStatusNone
109112
#define cudaGraphExec_t hipGraphExec_t
110113
#define cudaGraphNode_t hipGraphNode_t
111114
#define cudaKernelNodeParams hipKernelNodeParams

0 commit comments

Comments
 (0)