@@ -517,6 +517,8 @@ struct vk_device_struct {
 
     ggml_backend_buffer_type buffer_type;
 
+    bool disable_fusion;
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
@@ -652,6 +654,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;
     uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -667,8 +670,7 @@ struct vk_flash_attn_push_constants {
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;
 
@@ -1107,8 +1109,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -3531,6 +3533,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->idx = idx;
 
+        device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
+
         return device;
     }
 
@@ -6135,6 +6139,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
+    const uint32_t nem3 = mask ? mask->ne[3] : 0;
 
     const uint32_t HSK = nek0;
     const uint32_t HSV = nev0;
@@ -6202,7 +6207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
-        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
@@ -6372,17 +6377,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
+    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
                                               (uint32_t)neq2, (uint32_t)neq3,
                                               (uint32_t)nek2, (uint32_t)nek3,
                                               (uint32_t)nev2, (uint32_t)nev3,
-                                              nem1, nem2,
+                                              nem1, nem2, nem3,
                                               q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
                                               k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                                               scale, max_bias, logit_softcap,
-                                              mask != nullptr, n_head_log2, m0, m1,
+                                              mask_n_head_log2, m0, m1,
                                               gqa_ratio, split_kv, split_k };
 
     ggml_vk_sync_buffers(subctx);
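
A minimal sketch (editor's illustration, not part of this diff) of the bit layout used for mask_n_head_log2 above: the mask-present flag sits in bit 16 and the low 16 bits carry n_head_log2. The helper names and the unpacking side are assumptions for clarity, not code from the repository.

// Hypothetical helpers mirroring the packing expression in ggml_vk_flash_attn.
#include <cstdint>

static inline uint32_t pack_mask_n_head_log2(bool has_mask, uint32_t n_head_log2) {
    // same expression as in the hunk above
    return (uint32_t(has_mask) << 16) | n_head_log2;
}

static inline bool unpack_has_mask(uint32_t packed) {
    return ((packed >> 16) & 1u) != 0;  // flag bit
}

static inline uint32_t unpack_n_head_log2(uint32_t packed) {
    return packed & 0xFFFFu;            // low 16 bits
}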
@@ -7675,8 +7682,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -8906,7 +8912,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -9167,9 +9173,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // fused rms_norm + mul
             ggml_tensor *mul = cgraph->nodes[node_idx + 1];
             ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
         } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
         }
         break;
     case GGML_OP_RMS_NORM_BACK:
@@ -9329,7 +9335,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         ctx->compute_ctx.reset();
 
-        bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
+        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
         if (!ok) {
             if (node->op == GGML_OP_UNARY) {
                 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9344,7 +9350,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+    GGML_UNUSED(cgraph);
     ggml_backend_buffer * buf = nullptr;
 
     switch (tensor->op) {
@@ -9454,7 +9461,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     // Only run if ctx hasn't been submitted yet
     if (!subctx->seqs.empty()) {
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_0(tensor);
+        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
         use_fence = true;
 #endif
 
@@ -9474,7 +9481,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             ggml_vk_wait_for_fence(ctx);
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_1(tensor);
+        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
 #endif
     }
 
@@ -9921,6 +9928,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
     return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
 }
 
+static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            mul->src[0]->ne[1] != rms_norm->ne[1]) {
+            return false;
+        }
+        // rms_norm shader assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
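
To show how the new helper composes with a graph walk, here is a hypothetical diagnostic (not part of the commit; the function name is invented for illustration) that counts the RMS_NORM+MUL pairs ggml_vk_can_fuse would accept, assuming the same headers and types as ggml-vulkan.cpp.

// Hypothetical diagnostic: count fusable RMS_NORM+MUL pairs in a graph using
// the ggml_vk_can_fuse() helper introduced above.
static int ggml_vk_count_fusable_rms_norm_mul(const struct ggml_cgraph * cgraph) {
    int count = 0;
    for (int i = 0; i + 1 < cgraph->n_nodes; i++) {
        if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            count++;
            i++; // skip the MUL node consumed by the fusion
        }
    }
    return count;
}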
@@ -9934,7 +9972,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ctx->num_additional_fused_ops = 1;
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -10004,7 +10042,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ctx->num_additional_fused_ops = 1;
         }
 
@@ -10327,12 +10365,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                     return false;
                 }
-                // TODO: support broadcast
-                // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
-                // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-                if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
-                    return false;
-                }
                 // It's straightforward to support different K/V dequant, but would
                 // significantly increase the number of pipelines
                 if (op->src[1]->type != op->src[2]->type) {
@@ -10787,11 +10819,21 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
+    bool fused_rms_norm_mul = false;
+    int rms_norm_idx = -1;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     check_counter++;
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
@@ -10819,6 +10861,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     for (int i = 0; i < 6; i++) {
         ggml_tensor * srci = tensor->src[i];
+        if (fused_rms_norm_mul) {
+            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
+            ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
+            switch (i) {
+            case 0: srci = rms_norm->src[0]; break;
+            case 1: srci = tensor->src[1 - rms_norm_idx]; break;
+            default: continue;
+            }
+        }
         if (srci == nullptr) {
             continue;
         }
@@ -10876,7 +10927,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_SUB) {
         tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        if (fused_rms_norm_mul) {
+            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
+            tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
+        } else {
+            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        }
     } else if (tensor->op == GGML_OP_DIV) {
         tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
@@ -11067,10 +11123,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         GGML_ABORT("fatal error");
     }
 
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_clone);
+    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
+    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
 
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
 
     if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
         ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11093,10 +11149,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
 
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
+    bool fused_rms_norm_mul = false;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
     }