11#include " common.cuh"
22
3- static __device__ __forceinline__ void dequantize_q4_0 (const void * vx, const int64_t ib, const int iqs, float2 & v){
3+ static __device__ __forceinline__ void dequantize_q4_0 (const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
44 const block_q4_0 * x = (const block_q4_0 *) vx;
55
6- const float d = x[ib].d ;
6+ const dfloat d = x[ib].d ;
77
88 const int vui = x[ib].qs [iqs];
99
1010 v.x = vui & 0xF ;
1111 v.y = vui >> 4 ;
1212
13+ #ifdef GGML_CUDA_F16
14+ v = __hsub2 (v, {8 .0f , 8 .0f });
15+ v = __hmul2 (v, {d, d});
16+ #else
1317 v.x = (v.x - 8 .0f ) * d;
1418 v.y = (v.y - 8 .0f ) * d;
19+ #endif // GGML_CUDA_F16
1520}
1621
17- static __device__ __forceinline__ void dequantize_q4_1 (const void * vx, const int64_t ib, const int iqs, float2 & v){
22+ static __device__ __forceinline__ void dequantize_q4_1 (const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
1823 const block_q4_1 * x = (const block_q4_1 *) vx;
1924
20- const float2 dm = __half22float2 (x[ib].dm );
25+ const dfloat d = __low2half (x[ib].dm );
26+ const dfloat m = __high2half (x[ib].dm );
2127
2228 const int vui = x[ib].qs [iqs];
2329
2430 v.x = vui & 0xF ;
2531 v.y = vui >> 4 ;
2632
27- v.x = (v.x * dm.x ) + dm.y ;
28- v.y = (v.y * dm.x ) + dm.y ;
33+ #ifdef GGML_CUDA_F16
34+ v = __hmul2 (v, {d, d});
35+ v = __hadd2 (v, {m, m});
36+ #else
37+ v.x = (v.x * d) + m;
38+ v.y = (v.y * d) + m;
39+ #endif // GGML_CUDA_F16
2940}
3041
31- static __device__ __forceinline__ void dequantize_q5_0 (const void * vx, const int64_t ib, const int iqs, float2 & v){
42+ static __device__ __forceinline__ void dequantize_q5_0 (const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
3243 const block_q5_0 * x = (const block_q5_0 *) vx;
3344
34- const float d = x[ib].d ;
45+ const dfloat d = x[ib].d ;
3546
3647 uint32_t qh;
3748 memcpy (&qh, x[ib].qh , sizeof (qh));
@@ -42,14 +53,20 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
4253 v.x = ((x[ib].qs [iqs] & 0xf ) | xh_0);
4354 v.y = ((x[ib].qs [iqs] >> 4 ) | xh_1);
4455
56+ #ifdef GGML_CUDA_F16
57+ v = __hsub2 (v, {16 .0f , 16 .0f });
58+ v = __hmul2 (v, {d, d});
59+ #else
4560 v.x = (v.x - 16 .0f ) * d;
4661 v.y = (v.y - 16 .0f ) * d;
62+ #endif // GGML_CUDA_F16
4763}
4864
49- static __device__ __forceinline__ void dequantize_q5_1 (const void * vx, const int64_t ib, const int iqs, float2 & v){
65+ static __device__ __forceinline__ void dequantize_q5_1 (const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
5066 const block_q5_1 * x = (const block_q5_1 *) vx;
5167
52- const float2 dm = __half22float2 (x[ib].dm );
68+ const dfloat d = __low2half (x[ib].dm );
69+ const dfloat m = __high2half (x[ib].dm );
5370
5471 uint32_t qh;
5572 memcpy (&qh, x[ib].qh , sizeof (qh));
@@ -60,18 +77,27 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
6077 v.x = ((x[ib].qs [iqs] & 0xf ) | xh_0);
6178 v.y = ((x[ib].qs [iqs] >> 4 ) | xh_1);
6279
63- v.x = (v.x * dm.x ) + dm.y ;
64- v.y = (v.y * dm.x ) + dm.y ;
80+ #ifdef GGML_CUDA_F16
81+ v = __hmul2 (v, {d, d});
82+ v = __hadd2 (v, {m, m});
83+ #else
84+ v.x = (v.x * d) + m;
85+ v.y = (v.y * d) + m;
86+ #endif // GGML_CUDA_F16
6587}
6688
67- static __device__ __forceinline__ void dequantize_q8_0 (const void * vx, const int64_t ib, const int iqs, float2 & v){
89+ static __device__ __forceinline__ void dequantize_q8_0 (const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
6890 const block_q8_0 * x = (const block_q8_0 *) vx;
6991
70- const float d = x[ib].d ;
92+ const dfloat d = x[ib].d ;
7193
7294 v.x = x[ib].qs [iqs + 0 ];
7395 v.y = x[ib].qs [iqs + 1 ];
7496
97+ #ifdef GGML_CUDA_F16
98+ v = __hmul2 (v, {d, d});
99+ #else
75100 v.x *= d;
76101 v.y *= d;
102+ #endif // GGML_CUDA_F16
77103}
0 commit comments