Skip to content

Commit 0cc32ff

Browse files
ikawrakowIwan Kawrakow
andauthored
CUDA: muh faster prompt processing for MoE models and small u-batch sizes (#728)
* WIP: adding mainline mmq_id implementation * This seems to work * Now also -fmoe works * WIP * WIP * WIP * This works for mainline supported quants * mmq_id: add iq2_k, iq2_k_r4 * mmiq_id: don't assume row size is multiple of type size (per row scales) * mmiq_id: don't assume row size is multiple of type size * mmq_id: add iq2_ks So we are sure it works with per row scales * mmq_id: add iq2_kl * mmq_id: add iq3_ks * mmq_id: adding iq3_k, iq3_k_r4 * mmq_id: add iq4_kss, iq4_ks, iq4_ks_r4 * mmq_id: adding iq4_k, iq4_k_r4 * mmq_id: adding iq5_ks, iq5_ks_r4 * mmq_id: adding iq5_k, iq5_k_r4, q6_0 * mmq_id: adding iq6_k * mmq_id: add iq1_s_r4 * mmq_id: adding iq1_kt, iq2_kt * mmq_id: add iq3_kt, iq4_kt * Add CUDA fp8 header --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 50f7119 commit 0cc32ff

40 files changed

+6945
-193
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 216 additions & 193 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmq_id.cu

Lines changed: 564 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmq_id.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include "common.cuh"
4+
5+
void ggml_cuda_mul_mat_q_id(
6+
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids,
7+
ggml_tensor * dst, char * ids_data, char * src1_quantized_data);
8+
9+
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
10+
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);
11+
12+
bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11);

ggml/src/ggml-cuda/mmq_id_common.cuh

Lines changed: 4183 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/quantize_id.cu

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#include "quantize_id.cuh"
2+
#include "mmq.cuh"
3+
#include <cstdint>
4+
5+
template <mmq_q8_1_ds_layout ds_layout>
6+
static __global__ void quantize_mmq_q8_1(
7+
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
8+
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
9+
const int64_t ne0, const int ne1, const int ne2) {
10+
11+
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
12+
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
13+
14+
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
15+
16+
if (i0 >= ne0) {
17+
return;
18+
}
19+
20+
const int64_t i1 = blockIdx.x;
21+
const int64_t i2 = blockIdx.z % ne2;
22+
const int64_t i3 = blockIdx.z / ne2;
23+
24+
const int64_t i00 = i0;
25+
const int64_t i01 = ids ? ids[i1] : i1;
26+
const int64_t i02 = i2;
27+
const int64_t i03 = i3;
28+
29+
const float4 * x4 = (const float4 *) x;
30+
31+
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
32+
33+
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
34+
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
35+
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
36+
37+
// Load 4 floats per thread and calculate max. abs. value between them:
38+
const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
39+
float amax = fabsf(xi.x);
40+
amax = fmaxf(amax, fabsf(xi.y));
41+
amax = fmaxf(amax, fabsf(xi.z));
42+
amax = fmaxf(amax, fabsf(xi.w));
43+
44+
// Exchange max. abs. value between vals_per_scale/4 threads.
45+
#pragma unroll
46+
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
47+
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
48+
}
49+
50+
float sum;
51+
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
52+
sum = xi.x + xi.y + xi.z + xi.w;
53+
54+
// Calculate sums across vals_per_sum/4 threads.
55+
#pragma unroll
56+
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
57+
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
58+
}
59+
}
60+
61+
const float d_inv = 127.0f / amax;
62+
char4 q;
63+
q.x = roundf(xi.x*d_inv);
64+
q.y = roundf(xi.y*d_inv);
65+
q.z = roundf(xi.z*d_inv);
66+
q.w = roundf(xi.w*d_inv);
67+
68+
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
69+
char4 * yqs4 = (char4 *) y[ib].qs;
70+
yqs4[iqs/4] = q;
71+
72+
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
73+
if (iqs % 16 != 0 || iqs >= 96) {
74+
return;
75+
}
76+
77+
y[ib].d2s6[2 + iqs/16] = sum;
78+
79+
if (iqs % 64 != 0) {
80+
return;
81+
}
82+
83+
const float d = 1.0f / d_inv;
84+
85+
y[ib].d2s6[iqs/64] = d;
86+
87+
return;
88+
}
89+
90+
if (iqs % 32 != 0) {
91+
return;
92+
}
93+
94+
const float d = 1.0f / d_inv;
95+
96+
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
97+
y[ib].ds4[iqs/32] = make_half2(d, sum);
98+
} else {
99+
y[ib].d4[iqs/32] = d;
100+
}
101+
}
102+
103+
void quantize_mmq_q8_1_cuda_id(
104+
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
105+
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
106+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
107+
GGML_ASSERT(ne00 % 4 == 0);
108+
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
109+
GGML_ASSERT(ids);
110+
111+
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
112+
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
113+
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
114+
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
115+
switch (mmq_get_q8_1_ds_layout(type_src0)) {
116+
case MMQ_Q8_1_DS_LAYOUT_D4:
117+
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
118+
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
119+
break;
120+
case MMQ_Q8_1_DS_LAYOUT_DS4:
121+
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
122+
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
123+
break;
124+
case MMQ_Q8_1_DS_LAYOUT_D2S6:
125+
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
126+
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
127+
break;
128+
default:
129+
GGML_ABORT("fatal error");
130+
break;
131+
}
132+
}

ggml/src/ggml-cuda/quantize_id.cuh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "common.cuh"
4+
5+
#include <cstdint>
6+
7+
#define CUDA_QUANTIZE_BLOCK_SIZE 256
8+
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
9+
10+
//static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
11+
//static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
12+
13+
void quantize_mmq_q8_1_cuda_id(
14+
const float * x, const int32_t * ids, void * vy,
15+
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
16+
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmq_id_common.cuh"
4+
5+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_kt(
6+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
7+
8+
constexpr int nwarps = mmq_get_nwarps_device();
9+
10+
constexpr uint32_t ka = 0xCBAC1FED;
11+
constexpr uint32_t km = 0x3f3f3f3f;
12+
13+
#ifdef INT8_MMA_AVAILABLE
14+
int * x_qs = (int *) x_tile;
15+
float * x_df = (float *) (x_qs + WARP_SIZE*2);
16+
#else
17+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
18+
int * x_qs = (int *) x_tile;
19+
float * x_df = (float *) (x_qs + txs.qs);
20+
#endif // INT8_MMA_AVAILABLE
21+
22+
const int kqsx = threadIdx.x;
23+
24+
#pragma unroll
25+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
26+
int i = i0 + threadIdx.y;
27+
28+
if (need_check) {
29+
i = min(i, i_max);
30+
}
31+
32+
const block_iq1_kt * bxi = (const block_iq1_kt *)(x + i*stride + sizeof(float)) + kbx0;
33+
34+
int ib32 = kqsx/4;
35+
int j = kqsx%4;
36+
uint32_t val = bxi->ql[kqsx] + ((bxi->qh[kqsx%16] << (8 - 4*(kqsx/16))) & 0xf00) + ((bxi->sh[kqsx/4] << (8 - (kqsx%4))) & 0x1000) + 4096;
37+
int2 v = {0, 0};
38+
for (int k = 0; k < 4; ++k) {
39+
val *= ka;
40+
v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
41+
}
42+
for (int k = 0; k < 4; ++k) {
43+
val *= ka;
44+
v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
45+
}
46+
#ifdef INT8_MMA_AVAILABLE
47+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
48+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
49+
#else
50+
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
51+
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
52+
#endif // INT8_MMA_AVAILABLE
53+
}
54+
55+
#pragma unroll
56+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
57+
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
58+
59+
if (need_check) {
60+
i = min(i, i_max);
61+
}
62+
63+
const float * dptr = (const float *)(x + i*stride);
64+
const float d = dptr[0];
65+
const block_iq1_kt * bxi = (const block_iq1_kt *)(dptr + 1) + kbx0;
66+
const int ls = iq4k_values[bxi->sh[threadIdx.x % 8] & 0xf];
67+
68+
#ifdef INT8_MMA_AVAILABLE
69+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
70+
#else
71+
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
72+
#endif // INT8_MMA_AVAILABLE
73+
}
74+
}
75+
76+
template <int mmq_x, int mmq_y, bool need_check>
77+
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_KT> {
78+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_kt<mmq_y, need_check>;
79+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
80+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
81+
};
82+
83+
DECL_MMQ_CASE(GGML_TYPE_IQ1_KT);
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmq_id_common.cuh"
4+
5+
DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
6+
DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);

0 commit comments

Comments
 (0)