
Commit e6bfc3f

Revert "CUDA: replace GGML_CUDA_F16 with CUDA arch checks (ggml-org#15433)"
1 parent 86bafd0 commit e6bfc3f

7 files changed: +63, -25 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 9 additions & 1 deletion
@@ -210,6 +210,14 @@ static const char * cu_get_error_str(CUresult err) {
 #define GGML_CUDA_ASSUME(x)
 #endif // CUDART_VERSION >= 11010
 
+#ifdef GGML_CUDA_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif // GGML_CUDA_F16
+
 #if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))

@@ -555,7 +563,7 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
     const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
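
The restored dfloat/dfloat2 typedefs let the same dequantization source compile against either half or float precision, and since dequantize_kernel_t now takes a dfloat2 &, every kernel written against it switches precision with the build flag. As a minimal illustration (not part of the commit; make_dfloat2 is a hypothetical helper name):

// Illustration only: with GGML_CUDA_F16 defined, dfloat2 is half2 and the pair
// is packed with a CUDA FP16 intrinsic; otherwise dfloat2 is float2.
static __device__ __forceinline__ dfloat2 make_dfloat2(const float x, const float y) {
#ifdef GGML_CUDA_F16
    return __floats2half2_rn(x, y); // pack two floats into one half2 (round to nearest)
#else
    return make_float2(x, y);
#endif // GGML_CUDA_F16
}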

ggml/src/ggml-cuda/convert.cu

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    float2 v;
+    dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;

ggml/src/ggml-cuda/cpy.cu

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
 
 #pragma unroll
     for (int j = 0; j < QK8_0; j += 2) {
-        float2 dq;
+        dfloat2 dq;
         dequantize_q8_0(cxi, 0, j, dq);
         *(cdstf + j)     = dq.x;
         *(cdstf + j + 1) = dq.y;
@@ -55,7 +55,7 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 
 #pragma unroll
     for (int j = 0; j < qk/2; j++) {
-        float2 dq;
+        dfloat2 dq;
         dequant(cxi, 0, j, dq);
         *(cdstf + j)        = dq.x;
         *(cdstf + j + qk/2) = dq.y;
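
The call sites in convert.cu, cpy.cu, and getrows.cu only need the local variable's type changed because the components of a dfloat2 convert implicitly when stored to float. A sketch, not taken from the commit (demo_store_q8_0 is a hypothetical kernel):

// Illustration only: dq.x / dq.y are __half under GGML_CUDA_F16 and convert
// implicitly on assignment to float, so the surrounding store code is
// identical in both builds.
static __global__ void demo_store_q8_0(const void * vx, float * dst) {
    dfloat2 dq;
    dequantize_q8_0(vx, /*ib=*/0, /*iqs=*/0, dq);
    dst[0] = dq.x; // half -> float conversion when GGML_CUDA_F16 is set, plain copy otherwise
    dst[1] = dq.y;
}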

ggml/src/ggml-cuda/dequantize.cuh

Lines changed: 40 additions & 14 deletions
@@ -1,37 +1,48 @@
 #include "common.cuh"
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {8.0f, 8.0f});
+    v = __hmul2(v, {d, d});
+#else
     v.x = (v.x - 8.0f) * d;
     v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const float2 dm = __half22float2(x[ib].dm);
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-    v.x = (v.x * dm.x) + dm.y;
-    v.y = (v.y * dm.x) + dm.y;
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -42,14 +53,20 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {16.0f, 16.0f});
+    v = __hmul2(v, {d, d});
+#else
     v.x = (v.x - 16.0f) * d;
     v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const float2 dm = __half22float2(x[ib].dm);
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -60,18 +77,27 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
-    v.x = (v.x * dm.x) + dm.y;
-    v.y = (v.y * dm.x) + dm.y;
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     v.x = x[ib].qs[iqs + 0];
     v.y = x[ib].qs[iqs + 1];
 
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+#else
     v.x *= d;
     v.y *= d;
+#endif // GGML_CUDA_F16
 }
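
The common pattern restored in this file is that the FP16 build applies the scale and offset to both lanes with one packed half2 intrinsic, while the FP32 build uses two scalar operations. A standalone sketch of that pattern (scale_and_shift is a hypothetical name; the functions above inline it directly):

// Illustrative helper (not in the commit): compute "v = v*d + m" for both lanes.
static __device__ __forceinline__ void scale_and_shift(dfloat2 & v, const dfloat d, const dfloat m) {
#ifdef GGML_CUDA_F16
    v = __hmul2(v, {d, d}); // one packed half2 multiply
    v = __hadd2(v, {m, m}); // one packed half2 add
#else
    v.x = (v.x * d) + m;    // two scalar float multiply-adds
    v.y = (v.y * d) + m;
#endif // GGML_CUDA_F16
}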

ggml/src/ggml-cuda/getrows.cu

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ static __global__ void k_get_rows(
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    float2 v;
+    dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
     dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -3691,6 +3691,10 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
     features.push_back({ "NO_PEER_COPY", "1" });
 #endif
 
+#ifdef GGML_CUDA_F16
+    features.push_back({ "F16", "1" });
+#endif
+
 #ifdef GGML_CUDA_USE_GRAPHS
     features.push_back({ "USE_GRAPHS", "1" });
 #endif
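
With this change an FP16 build advertises itself through the backend feature list. A sketch of a runtime check, assuming (as for the other entries) that the array returned by ggml_backend_cuda_get_features is terminated by an entry whose name is NULL; cuda_build_uses_f16 is a hypothetical helper:

#include <cstring>
#include "ggml-backend.h"

// Hypothetical helper: scan the CUDA backend's feature list for the "F16" entry.
static bool cuda_build_uses_f16(const ggml_backend_feature * features) {
    for (; features && features->name != nullptr; ++features) {
        if (std::strcmp(features->name, "F16") == 0) {
            return std::strcmp(features->value, "1") == 0;
        }
    }
    return false; // entry absent: built without GGML_CUDA_F16
}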

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 6 additions & 6 deletions
@@ -87,7 +87,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
         sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
     }
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef GGML_CUDA_F16
     const float2 tmp = __half22float2(__hmul2(dm4, ds8));
     const float d4d8 = tmp.x;
     const float m4s8 = tmp.y;
@@ -96,7 +96,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     const float2 ds8f = __half22float2(ds8);
     const float d4d8 = dm4f.x * ds8f.x;
     const float m4s8 = dm4f.y * ds8f.y;
-#endif // FAST_FP16_AVAILABLE
+#endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
     return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
@@ -158,7 +158,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
         sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
     }
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef GGML_CUDA_F16
     const float2 tmp = __half22float2(__hmul2(dm5, ds8));
     const float d5d8 = tmp.x;
     const float m5s8 = tmp.y;
@@ -167,7 +167,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     const float2 ds8f = __half22float2(ds8);
     const float d5d8 = dm5f.x * ds8f.x;
     const float m5s8 = dm5f.y * ds8f.y;
-#endif // FAST_FP16_AVAILABLE
+#endif // GGML_CUDA_F16
 
     // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
     return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
@@ -201,7 +201,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
         sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
     }
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef GGML_CUDA_F16
     const float2 tmp = __half22float2(__hmul2(dm8, ds8));
     const float d8d8 = tmp.x;
     const float m8s8 = tmp.y;
@@ -210,7 +210,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     const float2 ds8f = __half22float2(ds8);
     const float d8d8 = dm8f.x * ds8f.x;
     const float m8s8 = dm8f.y * ds8f.y;
-#endif // FAST_FP16_AVAILABLE
+#endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
