Skip to content

Commit 699ad6a

Browse files
committed
CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID
ggml-org/llama.cpp#13014
1 parent 317a116 commit 699ad6a

File tree

8 files changed

+944
-161
lines changed

8 files changed

+944
-161
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 102 additions & 106 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmv.cu

Lines changed: 336 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmv.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
+#include "common.cuh"
+
+// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
+#define MMV_MAX_ROWS 512
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 464 additions & 32 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmvq.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
+void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
 void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,

ggml/src/ggml-cuda/quantize.cu

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -259,45 +259,47 @@ static __global__ void quantize_mmq_q8_1_id(
 }
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+    GGML_ASSERT(ne0 % QK8_1 == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
-
-    GGML_UNUSED(type_x);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+    GGML_UNUSED(type_src0);
 }
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+    GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, kx1, channels);
+    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
-    switch (mmq_get_q8_1_ds_layout(type_x)) {
+    switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
+    GGML_UNUSED(s01);
+    GGML_UNUSED(s02);
+    GGML_UNUSED(s03);
 }
 
 void quantize_mmq_q8_1_id_cuda(

ggml/src/ggml-cuda/quantize.cuh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_id_cuda(
     const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
 // SPDX-License-Identifier: MIT
 //
 
+#pragma once
+
 #include "common.cuh"
 #include <cstdint>
 

0 commit comments

Comments
 (0)