Skip to content

Commit a2feff9

Browse files
committed
Enable CUB
1 parent 7c84cc9 commit a2feff9

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ jobs:
357357
id: depends
358358
run: |
359359
sudo apt-get update
360-
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev
360+
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev hipcub-dev libcurl4-openssl-dev
361361
362362
- name: ccache
363363
uses: hendrikmuhs/ccache-action@v1.2.16

ggml/src/ggml-cuda/sum.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1-
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
1+
#if !defined(GGML_USE_MUSA) && (CUDART_VERSION >= 11070 || defined(GGML_USE_HIP))
22
#define USE_CUB
3-
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
3+
#endif // !defined(GGML_USE_MUSA) && (CUDART_VERSION >= 11070 || defined(GGML_USE_HIP))
44

55
#ifdef USE_CUB
6+
7+
#if defined(GGML_USE_HIP)
8+
#include <hipcub/hipcub.hpp>
9+
using namespace hipcub;
10+
#else
611
#include <cub/cub.cuh>
712
using namespace cub;
13+
#endif // defined(GGML_USE_HIP)
14+
815
#endif // USE_CUB
916

1017
#include "sumrows.cuh"
@@ -20,7 +27,6 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
2027
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
2128
#else
2229
// Use (inefficient) sum_rows implementation as a fallback.
23-
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
2430
sum_rows_f32_cuda(x, dst, ne, 1, stream);
2531
GGML_UNUSED(pool);
2632
#endif // USE_CUB

0 commit comments

Comments
 (0)