Skip to content

Commit 8e497e7

Browse files
ikawrakow and Iwan Kawrakow authored
Fused matrix multiplications (CUDA and CPU) (ikawrakow#796)
* Quick attempt to fuse the Q, K, V GEMMs Doesn't do much on the CPU * Doesn't do much on the GPU either * Use llm_build_mul_mat_qkv * This is not needed * Revert timing on committed by mistake --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0d1bbde commit 8e497e7

File tree

3 files changed

+244
-564
lines changed

3 files changed

+244
-564
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,7 +2143,62 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
21432143
}
21442144
}
21452145

2146-
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2146+
static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
2147+
const ggml_cgraph * cgraph, int node_n, bool is_gemv) {
2148+
2149+
auto stream = ctx.stream();
2150+
2151+
auto ne10_padded = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
2152+
auto nb10_padded = ne10_padded*sizeof(block_q8_1)/QK8_1;
2153+
auto quantized_size = nb10_padded*ggml_nrows(src1);
2154+
if (!is_gemv) {
2155+
quantized_size += get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
2156+
}
2157+
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), quantized_size);
2158+
if (is_gemv) {
2159+
quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], src1->ne[1], src1->ne[2], ne10_padded,
2160+
src0->type, stream);
2161+
CUDA_CHECK(cudaGetLastError());
2162+
2163+
ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
2164+
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
2165+
CUDA_CHECK(cudaGetLastError());
2166+
} else {
2167+
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
2168+
CUDA_CHECK(cudaGetLastError());
2169+
2170+
ggml_cuda_op_mul_mat_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
2171+
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
2172+
CUDA_CHECK(cudaGetLastError());
2173+
}
2174+
2175+
if (!cgraph) return node_n;
2176+
2177+
while (node_n + 1 < cgraph->n_nodes) {
2178+
dst = cgraph->nodes[node_n+1];
2179+
if (ggml_is_empty(dst) || dst->op == GGML_OP_RESHAPE || dst->op == GGML_OP_TRANSPOSE || dst->op == GGML_OP_VIEW
2180+
|| dst->op == GGML_OP_PERMUTE || dst->op == GGML_OP_NONE) {
2181+
++node_n; continue;
2182+
}
2183+
if (dst->op != GGML_OP_MUL_MAT || dst->src[1] != src1 || !ggml_is_quantized(dst->src[0]->type)) break;
2184+
if (!is_gemv && mmq_get_q8_1_ds_layout(src0->type) != mmq_get_q8_1_ds_layout(dst->src[0]->type)) break;
2185+
if (is_gemv) {
2186+
ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
2187+
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
2188+
} else {
2189+
ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
2190+
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
2191+
}
2192+
CUDA_CHECK(cudaGetLastError());
2193+
++node_n;
2194+
}
2195+
2196+
return node_n;
2197+
2198+
}
2199+
2200+
static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
2201+
const ggml_cgraph * cgraph, int node_n) {
21472202
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
21482203

21492204
// If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
@@ -2188,6 +2243,10 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
21882243
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
21892244
}
21902245

2246+
if (!split && (use_mul_mat_vec_q || use_mul_mat_q) && src1->ne[2]*src1->ne[3] == 1) {
2247+
return ggml_cuda_mul_mat_q(ctx, src0, src1, dst, cgraph, node_n, use_mul_mat_vec_q);
2248+
}
2249+
21912250
// debug helpers
21922251
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
21932252
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -2215,6 +2274,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
22152274
} else {
22162275
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
22172276
}
2277+
return node_n;
22182278
}
22192279

22202280
struct mmid_row_mapping {
@@ -2454,7 +2514,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
24542514
src1_row.data = src1_original + i11*nb11 + i12*nb12;
24552515
dst_row.data = dst_original + i1*nb1 + i2*nb2;
24562516

2457-
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2517+
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row, nullptr, 0);
24582518
}
24592519
}
24602520
} else {
@@ -2505,7 +2565,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
25052565
dst_row.nb[2] = num_src1_rows*nb1;
25062566
dst_row.nb[3] = num_src1_rows*nb1;
25072567

2508-
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2568+
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row, nullptr, 0);
25092569

25102570
{
25112571
dim3 block_dims(std::min((unsigned int)ne0, 768u));
@@ -2889,7 +2949,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
28892949
ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
28902950
0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
28912951
} else {
2892-
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
2952+
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row, nullptr, 0);
28932953
}
28942954
CUDA_CHECK(cudaGetLastError());
28952955

@@ -2906,7 +2966,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
29062966
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
29072967
0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
29082968
} else {
2909-
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
2969+
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0);
29102970
}
29112971
CUDA_CHECK(cudaGetLastError());
29122972

@@ -2947,8 +3007,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
29473007
(int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]);
29483008
first = false;
29493009
}
2950-
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
2951-
//ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst);
3010+
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0);
29523011
CUDA_CHECK(cudaGetLastError());
29533012

29543013
dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));
@@ -3031,8 +3090,7 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
30313090

30323091
}
30333092

3034-
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next,
3035-
const ggml_cgraph * cgraph, int & i) {
3093+
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
30363094
// why is this here instead of mul_mat?
30373095
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
30383096
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
@@ -3042,6 +3100,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
30423100
int64_t tim1 = ggml_time_us();
30433101
#endif
30443102

3103+
auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
3104+
30453105
switch (dst->op) {
30463106
case GGML_OP_REPEAT:
30473107
ggml_cuda_op_repeat(ctx, dst);
@@ -3112,7 +3172,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
31123172
ggml_cuda_op_hardswish(ctx, dst);
31133173
break;
31143174
default:
3115-
return false;
3175+
return -1;
31163176
}
31173177
break;
31183178
case GGML_OP_NORM:
@@ -3148,9 +3208,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
31483208
case GGML_OP_MUL_MAT:
31493209
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
31503210
GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
3151-
return false;
3211+
return -1;
31523212
} else {
3153-
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
3213+
i = ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst, cgraph, i);
31543214
}
31553215
break;
31563216
case GGML_OP_MUL_MAT_ID:
@@ -3569,7 +3629,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
35693629

35703630
for (int i = 0; i < cgraph->n_nodes; i++) {
35713631
ggml_tensor * node = cgraph->nodes[i];
3572-
ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
35733632

35743633
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
35753634
continue;
@@ -3604,7 +3663,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
36043663
GGML_UNUSED(integrated);
36053664
#endif // NDEBUG
36063665

3607-
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
3666+
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
36083667
if (!ok) {
36093668
GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
36103669
}

0 commit comments

Comments (0)