Commit 7a6e91a

CUDA: replace GGML_CUDA_F16 with CUDA arch checks (ggml-org#15433)
Parent: fec9519

12 files changed (+32, -86 lines)


docs/build.md

Lines changed: 6 additions & 7 deletions
@@ -197,13 +197,12 @@ The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enab
 
 The following compilation options are also available to tweak performance:
 
-| Option                        | Legal values     | Default | Description |
-|-------------------------------|------------------|---------|-------------|
-| GGML_CUDA_FORCE_MMQ           | Boolean          | false   | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
-| GGML_CUDA_FORCE_CUBLAS        | Boolean          | false   | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
-| GGML_CUDA_F16                 | Boolean          | false   | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
-| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128     | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
-| GGML_CUDA_FA_ALL_QUANTS       | Boolean          | false   | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
+| Option                        | Legal values     | Default | Description |
+|-------------------------------|------------------|---------|-------------|
+| GGML_CUDA_FORCE_MMQ           | Boolean          | false   | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
+| GGML_CUDA_FORCE_CUBLAS        | Boolean          | false   | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models. There may be issues with numerical overflows (except for CDNA and RDNA4) and memory use will be higher. Prompt processing may become faster on recent datacenter GPUs (the custom kernels were tuned primarily for RTX 3000/4000). |
+| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128     | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
+| GGML_CUDA_FA_ALL_QUANTS       | Boolean          | false   | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
 
 ## MUSA
 
docs/multimodal/MobileVLM.md

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ llama_print_timings: total time = 44411.01 ms / 377 tokens
 ## Orin compile and run
 ### compile
 ```sh
-make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32
+make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 -j 32
 ```
 ### run on Orin
 ### case 1

ggml/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -158,7 +158,6 @@ option(GGML_CUDA "ggml: use CUDA"
 option(GGML_MUSA "ggml: use MUSA" OFF)
 option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
 option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
-option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
 set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
     "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 0 additions & 10 deletions
@@ -24,12 +24,6 @@ if (CUDAToolkit_FOUND)
     # for best performance and to also build real architectures for the most commonly used GPUs.
     if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
         set(CMAKE_CUDA_ARCHITECTURES "native")
-    elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-        if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
-            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
-        else()
-            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
-        endif()
     else()
         if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
             set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
@@ -91,10 +85,6 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_FA)
     endif()
 
-    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-        add_compile_definitions(GGML_CUDA_F16)
-    endif()
-
     if (GGML_CUDA_NO_PEER_COPY)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 9 deletions
@@ -206,14 +206,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define GGML_CUDA_ASSUME(x)
 #endif // CUDART_VERSION >= 11010
 
-#ifdef GGML_CUDA_F16
-typedef half dfloat; // dequantize float
-typedef half2 dfloat2;
-#else
-typedef float dfloat; // dequantize float
-typedef float2 dfloat2;
-#endif // GGML_CUDA_F16
-
 #if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
@@ -559,7 +551,7 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
     const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
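
For context, `dequantize_kernel_t` is consumed by templated kernels such as `dequantize_block` in convert.cu and `k_get_rows` in getrows.cu (both touched below), which fill a `float2` two values per call. The following is a minimal, self-contained sketch of that pattern rather than ggml code: the kernel name, the simplified index math, and the restriction to the `qr == 1` layout (adjacent output values, e.g. q8_0) are illustrative assumptions.

```cpp
// Sketch only: a templated kernel consuming a function that matches the
// post-change dequantize_kernel_t signature (always float2, no half2 variant).
#include <cstdint>

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

template <int qk, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_pairs_sketch(const void * __restrict__ vx, float * __restrict__ y, const int64_t k) {
    // One thread produces one pair of outputs; assumes the qr == 1 layout in
    // which the two dequantized values are stored contiguously (ggml's real
    // kernels also handle qr == 2 with a qk/2 offset).
    const int64_t i = 2*(int64_t(blockIdx.x)*blockDim.x + threadIdx.x);
    if (i >= k) {
        return;
    }
    const int64_t ib  = i / qk;      // block index
    const int     iqs = int(i % qk); // quant index within the block

    float2 v;
    dequantize_kernel(vx, ib, iqs, v);

    y[i + 0] = v.x;
    y[i + 1] = v.y;
}
```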

ggml/src/ggml-cuda/convert.cu

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    dfloat2 v;
+    float2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;

ggml/src/ggml-cuda/cpy.cu

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
 
 #pragma unroll
     for (int j = 0; j < QK8_0; j += 2) {
-        dfloat2 dq;
+        float2 dq;
         dequantize_q8_0(cxi, 0, j, dq);
         *(cdstf + j) = dq.x;
         *(cdstf + j + 1) = dq.y;
@@ -55,7 +55,7 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 
 #pragma unroll
     for (int j = 0; j < qk/2; j++) {
-        dfloat2 dq;
+        float2 dq;
         dequant(cxi, 0, j, dq);
         *(cdstf + j) = dq.x;
         *(cdstf + j + qk/2) = dq.y;

ggml/src/ggml-cuda/dequantize.cuh

Lines changed: 14 additions & 40 deletions
@@ -1,48 +1,37 @@
 #include "common.cuh"
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const float d = x[ib].d;
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-#ifdef GGML_CUDA_F16
-    v = __hsub2(v, {8.0f, 8.0f});
-    v = __hmul2(v, {d, d});
-#else
     v.x = (v.x - 8.0f) * d;
     v.y = (v.y - 8.0f) * d;
-#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = __low2half(x[ib].dm);
-    const dfloat m = __high2half(x[ib].dm);
+    const float2 dm = __half22float2(x[ib].dm);
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-#ifdef GGML_CUDA_F16
-    v = __hmul2(v, {d, d});
-    v = __hadd2(v, {m, m});
-#else
-    v.x = (v.x * d) + m;
-    v.y = (v.y * d) + m;
-#endif // GGML_CUDA_F16
+    v.x = (v.x * dm.x) + dm.y;
+    v.y = (v.y * dm.x) + dm.y;
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const float d = x[ib].d;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -53,20 +42,14 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
-#ifdef GGML_CUDA_F16
-    v = __hsub2(v, {16.0f, 16.0f});
-    v = __hmul2(v, {d, d});
-#else
     v.x = (v.x - 16.0f) * d;
     v.y = (v.y - 16.0f) * d;
-#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const dfloat d = __low2half(x[ib].dm);
-    const dfloat m = __high2half(x[ib].dm);
+    const float2 dm = __half22float2(x[ib].dm);
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -77,27 +60,18 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
-#ifdef GGML_CUDA_F16
-    v = __hmul2(v, {d, d});
-    v = __hadd2(v, {m, m});
-#else
-    v.x = (v.x * d) + m;
-    v.y = (v.y * d) + m;
-#endif // GGML_CUDA_F16
+    v.x = (v.x * dm.x) + dm.y;
+    v.y = (v.y * dm.x) + dm.y;
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const float d = x[ib].d;
 
     v.x = x[ib].qs[iqs + 0];
     v.y = x[ib].qs[iqs + 1];
 
-#ifdef GGML_CUDA_F16
-    v = __hmul2(v, {d, d});
-#else
     v.x *= d;
     v.y *= d;
-#endif // GGML_CUDA_F16
 }
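
With the `GGML_CUDA_F16` branches removed above, every dequantize helper now does plain float math into a `float2`. As a self-contained illustration (not ggml's actual definitions), the sketch below pairs a simplified q8_0-style block with a dequantizer in the new shape and a caller mirroring the loop in `cpy_blck_q8_0_f32` above; the struct layout, names, and block size are assumptions made for the example only.

```cpp
// Sketch: a q8_0-style dequantizer in the post-change float2 form.
// block_q8_0_sketch is a stand-in; ggml's real block_q8_0 may differ.
#include <cstdint>
#include <cuda_fp16.h>

#define QK8_0_SKETCH 32

struct block_q8_0_sketch {
    half   d;                 // per-block scale
    int8_t qs[QK8_0_SKETCH];  // quantized values
};

static __device__ __forceinline__ void dequantize_q8_0_sketch(
        const void * vx, const int64_t ib, const int iqs, float2 & v) {
    const block_q8_0_sketch * x = (const block_q8_0_sketch *) vx;

    const float d = __half2float(x[ib].d);

    // Two adjacent quants per call, always computed in float, matching the
    // dequantize_kernel_t signature after this commit.
    v.x = x[ib].qs[iqs + 0] * d;
    v.y = x[ib].qs[iqs + 1] * d;
}

// Caller in the style of cpy_blck_q8_0_f32: dequantize one block into
// QK8_0_SKETCH floats, two values at a time.
static __device__ void dequantize_block_sketch(const char * cxi, float * dst) {
#pragma unroll
    for (int j = 0; j < QK8_0_SKETCH; j += 2) {
        float2 dq;
        dequantize_q8_0_sketch(cxi, 0, j, dq);
        dst[j + 0] = dq.x;
        dst[j + 1] = dq.y;
    }
}
```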

ggml/src/ggml-cuda/getrows.cu

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ static __global__ void k_get_rows(
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    dfloat2 v;
+    float2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
     dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 0 additions & 4 deletions
@@ -3672,10 +3672,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
     features.push_back({ "NO_PEER_COPY", "1" });
 #endif
 
-#ifdef GGML_CUDA_F16
-    features.push_back({ "F16", "1" });
-#endif
-
 #ifdef GGML_CUDA_USE_GRAPHS
     features.push_back({ "USE_GRAPHS", "1" });
 #endif
