Skip to content

Commit 75d386a

Browse files
committed
Perf: 7.1k tokens/sec
1 parent a161900 commit 75d386a

File tree

2 files changed

+144
-102
lines changed

2 files changed

+144
-102
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 141 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -226,27 +226,30 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
226226

227227
static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int cc) {
228228
if (amd_mma_available(cc)) {
229-
// 32x32 tile_C -> 32 ; 16x16 tile_C -> 16
230229
switch (type) {
231-
case GGML_TYPE_Q4_0: return 16; // vec_dot_q8_0_q8_1_mma
232-
case GGML_TYPE_Q4_1: return 16; // vec_dot_q8_1_q8_1_mma
233-
case GGML_TYPE_Q5_0: return 16; // vec_dot_q8_0_q8_1_mma
234-
case GGML_TYPE_Q5_1: return 16; // vec_dot_q8_1_q8_1_mma
235-
case GGML_TYPE_Q8_0: return 16; // vec_dot_q8_0_q8_1_mma
236-
case GGML_TYPE_Q2_K: return 32; // vec_dot_q2_K_q8_1_mma
237-
case GGML_TYPE_Q3_K: return 32; // vec_dot_q8_0_16_q8_1_mma
238-
case GGML_TYPE_Q4_K: return 16; // vec_dot_q8_1_q8_1_mma
239-
case GGML_TYPE_Q5_K: return 16; // vec_dot_q8_1_q8_1_mma
240-
case GGML_TYPE_Q6_K: return 32; // vec_dot_q6_K_q8_1_mma
241-
case GGML_TYPE_IQ2_XXS: return 16; // vec_dot_q8_0_q8_1_mma
242-
case GGML_TYPE_IQ2_XS: return 32; // vec_dot_q8_0_16_q8_1_mma
243-
case GGML_TYPE_IQ2_S: return 32; // vec_dot_q8_0_16_q8_1_mma
244-
case GGML_TYPE_IQ3_XXS: return 16; // vec_dot_q8_0_q8_1_mma
245-
case GGML_TYPE_IQ3_S: return 16; // vec_dot_q8_0_q8_1_mma
246-
case GGML_TYPE_IQ1_S: return 16; // vec_dot_q8_1_q8_1_mma
247-
case GGML_TYPE_IQ4_XS: return 16; // vec_dot_q8_0_q8_1_mma
248-
case GGML_TYPE_IQ4_NL: return 16; // vec_dot_q8_0_q8_1_mma
249-
default: return 0;
230+
// vec_dot_q8_0_q8_1_mma
231+
case GGML_TYPE_Q4_0:
232+
case GGML_TYPE_Q5_0:
233+
case GGML_TYPE_Q8_0:
234+
case GGML_TYPE_IQ2_XXS:
235+
case GGML_TYPE_IQ3_XXS:
236+
case GGML_TYPE_IQ3_S:
237+
case GGML_TYPE_IQ4_XS:
238+
case GGML_TYPE_IQ4_NL:
239+
return mmq_x >= 128 ? 32 : 16;
240+
// vec_dot_q8_1_q8_1_mma
241+
case GGML_TYPE_Q4_1:
242+
case GGML_TYPE_Q5_1:
243+
case GGML_TYPE_Q4_K:
244+
case GGML_TYPE_Q5_K:
245+
case GGML_TYPE_IQ1_S:
246+
return mmq_x >= 128 ? 32 : 16;
247+
case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
248+
case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
249+
case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
250+
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
251+
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
252+
return mmq_x >= 192 ? 64 : 32;
250253
}
251254
} else if (new_mma_available(cc) && mmq_x >= 48) {
252255
return 16;
@@ -256,26 +259,29 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
256259
#if defined(AMD_MMA_AVAILABLE)
257260
static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const int mmq_x) {
258261
switch (type) {
259-
// 32x32 tile_C -> 32 ; 16x16 tile_C -> 16
260-
case GGML_TYPE_Q4_0: return 16;
261-
case GGML_TYPE_Q4_1: return 16;
262-
case GGML_TYPE_Q5_0: return 16;
263-
case GGML_TYPE_Q5_1: return 16;
264-
case GGML_TYPE_Q8_0: return 16;
265-
case GGML_TYPE_Q2_K: return 32;
266-
case GGML_TYPE_Q3_K: return 32;
267-
case GGML_TYPE_Q4_K: return 16;
268-
case GGML_TYPE_Q5_K: return 16;
269-
case GGML_TYPE_Q6_K: return 32;
270-
case GGML_TYPE_IQ2_XXS: return 16;
271-
case GGML_TYPE_IQ2_XS: return 32;
272-
case GGML_TYPE_IQ2_S: return 32;
273-
case GGML_TYPE_IQ3_XXS: return 16;
274-
case GGML_TYPE_IQ3_S: return 16;
275-
case GGML_TYPE_IQ1_S: return 16;
276-
case GGML_TYPE_IQ4_XS: return 16;
277-
case GGML_TYPE_IQ4_NL: return 16;
278-
default: return 0;
262+
// vec_dot_q8_0_q8_1_mma
263+
case GGML_TYPE_Q4_0:
264+
case GGML_TYPE_Q5_0:
265+
case GGML_TYPE_Q8_0:
266+
case GGML_TYPE_IQ2_XXS:
267+
case GGML_TYPE_IQ3_XXS:
268+
case GGML_TYPE_IQ3_S:
269+
case GGML_TYPE_IQ4_XS:
270+
case GGML_TYPE_IQ4_NL:
271+
return mmq_x >= 128 ? 32 : 16;
272+
// vec_dot_q8_1_q8_1_mma
273+
case GGML_TYPE_Q4_1:
274+
case GGML_TYPE_Q5_1:
275+
case GGML_TYPE_Q4_K:
276+
case GGML_TYPE_Q5_K:
277+
case GGML_TYPE_IQ1_S:
278+
return mmq_x >= 128 ? 32 : 16;
279+
case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
280+
case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
281+
case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
282+
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
283+
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
284+
return mmq_x >= 192 ? 64 : 32;
279285
}
280286
}
281287
#elif defined(NEW_MMA_AVAILABLE)
@@ -290,27 +296,30 @@ static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const
290296

291297
static int get_mmq_nwarps_host(ggml_type type, const int cc) {
292298
if (amd_mma_available(cc)) {
293-
// 32x32 tile_C -> 4 ; 16x16 tile_C -> 8
294299
switch (type) {
295-
case GGML_TYPE_Q4_0: return 8;
296-
case GGML_TYPE_Q4_1: return 8;
297-
case GGML_TYPE_Q5_0: return 8;
298-
case GGML_TYPE_Q5_1: return 8;
299-
case GGML_TYPE_Q8_0: return 8;
300-
case GGML_TYPE_Q2_K: return 4;
301-
case GGML_TYPE_Q3_K: return 4;
302-
case GGML_TYPE_Q4_K: return 8;
303-
case GGML_TYPE_Q5_K: return 8;
304-
case GGML_TYPE_Q6_K: return 4;
305-
case GGML_TYPE_IQ2_XXS: return 8;
306-
case GGML_TYPE_IQ2_XS: return 4;
307-
case GGML_TYPE_IQ2_S: return 4;
308-
case GGML_TYPE_IQ3_XXS: return 8;
309-
case GGML_TYPE_IQ3_S: return 8;
310-
case GGML_TYPE_IQ1_S: return 8;
311-
case GGML_TYPE_IQ4_XS: return 8;
312-
case GGML_TYPE_IQ4_NL: return 8;
313-
default: return 0;
300+
// vec_dot_q8_0_q8_1_mma
301+
case GGML_TYPE_Q4_0:
302+
case GGML_TYPE_Q5_0:
303+
case GGML_TYPE_Q8_0:
304+
case GGML_TYPE_IQ2_XXS:
305+
case GGML_TYPE_IQ3_XXS:
306+
case GGML_TYPE_IQ3_S:
307+
case GGML_TYPE_IQ4_XS:
308+
case GGML_TYPE_IQ4_NL:
309+
return 8;
310+
// vec_dot_q8_1_q8_1_mma
311+
case GGML_TYPE_Q4_1:
312+
case GGML_TYPE_Q5_1:
313+
case GGML_TYPE_Q4_K:
314+
case GGML_TYPE_Q5_K:
315+
case GGML_TYPE_IQ1_S:
316+
return 8;
317+
case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
318+
case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
319+
case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
320+
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
321+
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
322+
return 4;
314323
}
315324
} else {
316325
return 8;
@@ -319,27 +328,30 @@ static int get_mmq_nwarps_host(ggml_type type, const int cc) {
319328

320329
#if defined(AMD_MMA_AVAILABLE)
321330
static constexpr __device__ int get_mmq_nwarps_device(ggml_type type) {
322-
// 32x32 tile_C -> 4 ; 16x16 tile_C -> 8
323331
switch (type) {
324-
case GGML_TYPE_Q4_0: return 8;
325-
case GGML_TYPE_Q4_1: return 8;
326-
case GGML_TYPE_Q5_0: return 8;
327-
case GGML_TYPE_Q5_1: return 8;
328-
case GGML_TYPE_Q8_0: return 8;
329-
case GGML_TYPE_Q2_K: return 4;
330-
case GGML_TYPE_Q3_K: return 4;
331-
case GGML_TYPE_Q4_K: return 8;
332-
case GGML_TYPE_Q5_K: return 8;
333-
case GGML_TYPE_Q6_K: return 4;
334-
case GGML_TYPE_IQ2_XXS: return 8;
335-
case GGML_TYPE_IQ2_XS: return 4;
336-
case GGML_TYPE_IQ2_S: return 4;
337-
case GGML_TYPE_IQ3_XXS: return 8;
338-
case GGML_TYPE_IQ3_S: return 8;
339-
case GGML_TYPE_IQ1_S: return 8;
340-
case GGML_TYPE_IQ4_XS: return 8;
341-
case GGML_TYPE_IQ4_NL: return 8;
342-
default: return 0;
332+
// vec_dot_q8_0_q8_1_mma
333+
case GGML_TYPE_Q4_0:
334+
case GGML_TYPE_Q5_0:
335+
case GGML_TYPE_Q8_0:
336+
case GGML_TYPE_IQ2_XXS:
337+
case GGML_TYPE_IQ3_XXS:
338+
case GGML_TYPE_IQ3_S:
339+
case GGML_TYPE_IQ4_XS:
340+
case GGML_TYPE_IQ4_NL:
341+
return 8;
342+
// vec_dot_q8_1_q8_1_mma
343+
case GGML_TYPE_Q4_1:
344+
case GGML_TYPE_Q5_1:
345+
case GGML_TYPE_Q4_K:
346+
case GGML_TYPE_Q5_K:
347+
case GGML_TYPE_IQ1_S:
348+
return 8;
349+
case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
350+
case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
351+
case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
352+
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
353+
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
354+
return 4;
343355
}
344356
}
345357
#else
@@ -896,35 +908,49 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
896908
typedef tile<16, 8, int> tile_B;
897909
typedef tile<16, 16, int> tile_C;
898910

911+
constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q4_K, mmq_x);
912+
constexpr int rows_per_warp = granularity;
913+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
914+
915+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
916+
899917
const int * x_qs = (const int *) x;
900918
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
901919
const int * y_qs = (const int *) y + 4;
902920
const half2 * y_dm = (const half2 *) y;
903921

904-
const int i0 = threadIdx.y * tile_A::I;
922+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
905923

906924
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
907925
const int k0 = k00 + k01;
908926

909-
tile_A A;
910-
load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
927+
tile_A A[ntx];
928+
#pragma unroll
929+
for (int n = 0; n < ntx; ++n) {
930+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
931+
}
911932

912933
#pragma unroll
913-
for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
934+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
914935
tile_B B;
915936
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
916937

917938
float2 dsB;
918939
const int j = j0 + tile_C::get_j(0);
919940
dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
920941

921-
tile_C C;
922-
mma(C, A, B);
942+
#pragma unroll
943+
for (int n = 0; n < ntx; ++n) {
944+
tile_C C;
945+
mma(C, A[n], B);
923946

924-
for (int l = 0; l < tile_C::ne; ++l) {
925-
float2 dmA = __half22float2(x_dm[(i0 + tile_C::get_i(l))*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
926-
sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
927-
sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.y*dsB.y;
947+
#pragma unroll
948+
for (int l = 0; l < tile_C::ne; ++l) {
949+
const int i = i0 + n*tile_A::I + tile_C::get_i(l);
950+
float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
951+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
952+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
953+
}
928954
}
929955
}
930956
}
@@ -2067,34 +2093,48 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20672093
typedef tile<32, 4, int> tile_B;
20682094
typedef tile<32, 32, int> tile_C;
20692095

2096+
constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q6_K, mmq_x);
2097+
constexpr int rows_per_warp = granularity;
2098+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2099+
2100+
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
2101+
20702102
const int * x_qs = (const int *) x;
20712103
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
20722104
const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
20732105
const int * y_qs = (const int *) y + 4;
20742106
const float * y_df = (const float *) y;
20752107

2076-
const int i0 = threadIdx.y * tile_A::I;
2108+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
20772109

20782110
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
20792111
const int k0 = k00 + k01;
20802112

2081-
tile_A A;
2082-
load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2113+
tile_A A[ntx];
2114+
#pragma unroll
2115+
for (int n = 0; n < ntx; ++n) {
2116+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2117+
}
20832118

2084-
for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
2119+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
20852120
tile_B B;
20862121
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
20872122

20882123
float dB;
20892124
const int j = j0 + tile_C::get_j(0);
20902125
dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
20912126

2092-
tile_C C;
2093-
mma(C, A, B);
2127+
#pragma unroll
2128+
for (int n = 0; n < ntx; ++n) {
2129+
tile_C C;
2130+
mma(C, A[n], B);
20942131

2095-
for (int l = 0; l < tile_C::ne; ++l) {
2096-
const int8_t * sc = (const int8_t *) (x_sc + (i0 + tile_C::get_i(l))*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2097-
sum[(j0/tile_C::J)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[(i0 + tile_C::get_i(l))*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2132+
#pragma unroll
2133+
for (int l = 0; l < tile_C::ne; ++l) {
2134+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2135+
const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2136+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2137+
}
20982138
}
20992139
}
21002140
}
@@ -2618,7 +2658,8 @@ static __device__ __forceinline__ void mmq_write_back_mma(
26182658
constexpr int nwarps = get_mmq_nwarps_device(type);
26192659

26202660
#if defined(AMD_MMA_AVAILABLE)
2621-
typedef tile<granularity, granularity, int> tile_C;
2661+
constexpr int tileC_IJ = mmq_get_granularity_device(type, 0);
2662+
typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
26222663
constexpr int rows_per_warp = granularity;
26232664
#else
26242665
typedef tile<16, 8, int> tile_C;

ggml/src/ggml-cuda/quantize.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ static __global__ void quantize_q8_1(
3131
float amax = fabsf(xi);
3232
float sum = xi;
3333

34-
amax = warp_reduce_max(amax);
35-
sum = warp_reduce_sum(sum);
34+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
35+
amax = warp_reduce_max<warp_size>(amax);
36+
sum = warp_reduce_sum<warp_size>(sum);
3637

3738
const float d = amax / 127;
3839
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

0 commit comments

Comments
 (0)