Skip to content

Commit 4c0b660

Browse files
Authored by ikawrakow (Iwan Kawrakow) and co-authors
CUDA: small PP performance improvement for MoE models (#589)
* Trying to implement quantized fmoe - not working yet
* This works, but is slower than the non-working version
* quantize_mmq_q8_1_id
* Minor

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 6f3a3ba commit 4c0b660

File tree

3 files changed

+161
-8
lines changed

3 files changed

+161
-8
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,7 @@ struct mmid_row_mapping {
21862186
int32_t i2;
21872187
};
21882188

2189+
template <typename data_t = float>
21892190
static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
21902191
const mmid_row_mapping * __restrict__ row_mapping,
21912192
int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
@@ -2194,8 +2195,8 @@ static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_or
21942195
const int32_t i11 = row_mapping[i].i1 % ne11;
21952196
const int32_t i12 = row_mapping[i].i2;
21962197

2197-
float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
2198-
const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
2198+
data_t * src_row_contiguous = (data_t *)(src_contiguous + i*nb11);
2199+
const data_t * src_row_original = (const data_t *)(src_original + i11*nb11 + i12*nb12);
21992200

22002201
for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
22012202
src_row_contiguous[j] = src_row_original[j];
@@ -2673,6 +2674,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
26732674
}
26742675
}
26752676
} else {
2677+
//printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
2678+
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
2679+
bool use_quantized_src1 = false;
2680+
int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
2681+
if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
2682+
src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
2683+
src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
2684+
src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
2685+
src1_quantized.alloc(src1_quantized_size);
2686+
use_quantized_src1 = true;
2687+
}
26762688
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
26772689
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
26782690
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@@ -2704,7 +2716,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
27042716
if (num_src1_rows == 0) continue;
27052717
size_t mapping_offset = cum_moe_counts[i02];
27062718

2707-
{
2719+
if (use_quantized_src1) {
2720+
quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset),
2721+
src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream);
2722+
CUDA_CHECK(cudaGetLastError());
2723+
src1_row.data = src1_quantized.get();
2724+
}
2725+
else {
27082726
dim3 block_dims(std::min((unsigned int)ne10, 768u));
27092727
dim3 grid_dims(num_src1_rows);
27102728
k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
@@ -2719,21 +2737,31 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
27192737
GGML_ASSERT(nb1 == sizeof(float)*ne0);
27202738

27212739
src1_row.ne[1] = num_src1_rows;
2722-
src1_row.nb[1] = nb11;
2723-
src1_row.nb[2] = num_src1_rows*nb11;
2724-
src1_row.nb[3] = num_src1_rows*nb11;
2740+
src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
2741+
src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
2742+
src1_row.nb[3] = num_src1_rows*src1_row.nb[1];
27252743

27262744
dst_row.ne[1] = num_src1_rows;
27272745
dst_row.nb[1] = nb1;
27282746
dst_row.nb[2] = num_src1_rows*nb1;
27292747
dst_row.nb[3] = num_src1_rows*nb1;
27302748

27312749
dst_row.data = dst_up_contiguous.get();
2732-
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
2750+
if (use_quantized_src1) {
2751+
ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
2752+
0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
2753+
} else {
2754+
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
2755+
}
27332756
CUDA_CHECK(cudaGetLastError());
27342757

27352758
dst_row.data = dst_gate_contiguous.get();
2736-
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
2759+
if (use_quantized_src1) {
2760+
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
2761+
0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
2762+
} else {
2763+
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
2764+
}
27372765
CUDA_CHECK(cudaGetLastError());
27382766

27392767
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),

ggml/src/ggml-cuda/quantize.cu

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,98 @@ static __global__ void quantize_mmq_q8_1(
166166
}
167167
}
168168

169+
struct mmid_row_mapping {
170+
int32_t i1;
171+
int32_t i2;
172+
};
173+
174+
template <mmq_q8_1_ds_layout ds_layout>
175+
static __global__ void quantize_mmq_q8_1_id(
176+
const float * __restrict__ x, void * __restrict__ vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
177+
178+
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
179+
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
180+
181+
const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
182+
183+
if (ix0 >= kx0_padded) {
184+
return;
185+
}
186+
187+
const float4 * x4 = (const float4 *) x;
188+
189+
const mmid_row_mapping * mapping = (const mmid_row_mapping *)row_mapping;
190+
const int64_t ii = mapping[blockIdx.y].i2;
191+
192+
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
193+
194+
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
195+
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
196+
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
197+
198+
// Load 4 floats per thread and calculate max. abs. value between them:
199+
const float4 xi = ix0 < kx0 ? x4[(ii*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
200+
float amax = fabsf(xi.x);
201+
amax = fmaxf(amax, fabsf(xi.y));
202+
amax = fmaxf(amax, fabsf(xi.z));
203+
amax = fmaxf(amax, fabsf(xi.w));
204+
205+
// Exchange max. abs. value between vals_per_scale/4 threads.
206+
#pragma unroll
207+
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
208+
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
209+
}
210+
211+
float sum;
212+
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
213+
sum = xi.x + xi.y + xi.z + xi.w;
214+
215+
// Exchange calculate sum across vals_per_sum/4 threads.
216+
#pragma unroll
217+
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
218+
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
219+
}
220+
}
221+
222+
const float d = amax/127.f;
223+
const float d_inv = d > 0 ? 1/d : 0.f;
224+
char4 q;
225+
q.x = roundf(xi.x*d_inv);
226+
q.y = roundf(xi.y*d_inv);
227+
q.z = roundf(xi.z*d_inv);
228+
q.w = roundf(xi.w*d_inv);
229+
230+
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
231+
char4 * yqs4 = (char4 *) y[ib].qs;
232+
yqs4[iqs/4] = q;
233+
234+
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
235+
if (iqs % 16 != 0 || iqs >= 96) {
236+
return;
237+
}
238+
239+
y[ib].d2s6[2 + iqs/16] = sum;
240+
241+
if (iqs % 64 != 0) {
242+
return;
243+
}
244+
245+
y[ib].d2s6[iqs/64] = d;
246+
247+
return;
248+
}
249+
250+
if (iqs % 32 != 0) {
251+
return;
252+
}
253+
254+
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
255+
y[ib].ds4[iqs/32] = make_half2(d, sum);
256+
} else {
257+
y[ib].d4[iqs/32] = d;
258+
}
259+
}
260+
169261
void quantize_row_q8_1_cuda(
170262
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
171263
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
@@ -208,6 +300,35 @@ void quantize_mmq_q8_1_cuda(
208300
}
209301
}
210302

303+
// Quantize the rows of x selected by row_mapping (mmid_row_mapping entries,
// one per output row) into the block_q8_1_mmq layout for the MMQ kernels.
// Grid: x tiles the padded row width (4 floats per thread), y is one block
// column per selected row (kx1 rows), z is unused (single channel).
void quantize_mmq_q8_1_id_cuda(
    const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
    const ggml_type type_x, cudaStream_t stream) {

    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);

    // Each thread covers 4 consecutive values -> ceil-div by 4*block size.
    const int64_t nblocks_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
    const dim3 grid_dims(nblocks_x, kx1, 1);
    const dim3 block_dims(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);

    // Dispatch on the scale/sum layout required by the destination weight type.
    const mmq_q8_1_ds_layout layout = mmq_get_q8_1_ds_layout(type_x);
    if (layout == MMQ_Q8_1_DS_LAYOUT_D4) {
        quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D4>
            <<<grid_dims, block_dims, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
    } else if (layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
        quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_DS4>
            <<<grid_dims, block_dims, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
    } else if (layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
        quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D2S6>
            <<<grid_dims, block_dims, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
    } else {
        GGML_ABORT("fatal error");
    }
}
330+
331+
211332
void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) {
212333
GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1);
213334
GGML_ASSERT(src->type == GGML_TYPE_F32);

ggml/src/ggml-cuda/quantize.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,9 @@ void quantize_mmq_q8_1_cuda(
3030
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
3131
const ggml_type type_x, cudaStream_t stream);
3232

33+
// Same as quantize_mmq_q8_1_cuda, but the kx1 source rows of x are selected
// indirectly through row_mapping (an array of mmid_row_mapping entries, one per
// output row) — used on the MoE (mul_mat_id) path.
void quantize_mmq_q8_1_id_cuda(
    const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
    const ggml_type type_x, cudaStream_t stream);
36+
3337
// For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1
3438
void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream);

0 commit comments

Comments
 (0)