@@ -9697,32 +9697,6 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
96979697 return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
96989698}
96999699
9700- // Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL: node i must be an RMS_NORM accepted by ggml_can_fuse_node(norm, 1), and node i+1 must be a MUL whose src[0] is that RMS_NORM node.
9701- static bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) { // ctx is not read in this function; i indexes cgraph->nodes
9702- ggml_tensor *norm = cgraph->nodes[i];
9703-
9704- if (norm->op != GGML_OP_RMS_NORM) { // the candidate pair must start with an RMS_NORM
9705- return false;
9706- }
9707-
9708- if (!ggml_can_fuse_node(norm, 1)) { // helper defined outside this view; NOTE(review): presumably checks that norm has exactly one use -- confirm
9709- return false;
9710- }
9711-
9712- if (i + 1 >= cgraph->n_nodes) { // norm is the last node, so there is no following node to pair with
9713- return false;
9714- }
9715- ggml_tensor *mul = cgraph->nodes[i + 1];
9716- if (mul->op != GGML_OP_MUL || mul->src[0] != norm) { // the immediately following node must be a MUL taking norm as its first source
9717- return false;
9718- }
9719-
9720- // Since norm is the first operand of mul, it must be the same shape (enforced with GGML_ASSERT below rather than by returning false)
9721- GGML_ASSERT(ggml_are_same_shape(mul, norm));
9722-
9723- return true;
9724- }
9725-
97269700static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
97279701 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
97289702 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9736,7 +9710,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97369710
97379711 uint64_t total_mat_mul_bytes = 0;
97389712 for (int i = 0; i < cgraph->n_nodes; i++) {
9739- if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9713+ if (ggml_can_fuse( cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL } )) {
97409714 ctx->num_additional_fused_ops = 1;
97419715 }
97429716 ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
@@ -9806,7 +9780,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
98069780 mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
98079781 }
98089782
9809- if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9783+ if (ggml_can_fuse( cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL } )) {
98109784 ctx->num_additional_fused_ops = 1;
98119785 }
98129786
0 commit comments