@@ -830,6 +830,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -852,6 +853,7 @@ struct vk_mat_vec_id_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t nei0;
     uint32_t ne11;
 };
@@ -6863,7 +6865,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
+        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -7684,13 +7686,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t enable_bias = 0;
+    uint32_t enable_scale = 0;
+    if (ctx->num_additional_fused_ops > 0) {
+        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
+            enable_scale = 1;
+        } else {
+            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
+            enable_bias = 1;
+        }
+    }
 
     vk_buffer d_B = d_D;
     size_t b_buf_offset = 0;
     uint64_t b_sz = 0;
 
-    if (enable_bias) {
+    if (enable_bias || enable_scale ) {
         const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
 
         bool b_uma = false;
@@ -7712,7 +7723,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
 
-        enable_bias,
+        enable_bias, enable_scale,
 
         (uint32_t)nei0, (uint32_t)ne11,
     };
@@ -12490,6 +12501,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
         }
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *mmid = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        const ggml_tensor *scale = mul->src[1];
+
+        if (mmid != mul->src[0]) {
+            return false;
+        }
+        // mat-vec only
+        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
+            return false;
+        }
+        // shaders assume the types match
+        if (mmid->type != scale->type) {
+            return false;
+        }
+        // shaders assume the bias is contiguous
+        if (!ggml_is_contiguous(scale)) {
+            return false;
+        }
+        // unaligned bias isn't handled
+        if (get_misalign_bytes(ctx, scale) != 0) {
+            return false;
+        }
+        // shader only indexes by expert index
+        if (scale->ne[0] != 1 ||
+            scale->ne[1] != mul->ne[1] ||
+            scale->ne[2] != 1 ||
+            scale->ne[3] != 1) {
+            return false;
+        }
+    }
+
     return true;
 }
 
@@ -12798,6 +12843,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
                        ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
@@ -13033,7 +13080,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                     is_src_of(graph->nodes[j], graph->nodes[c]) &&
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
-                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
                     ok = false;
                     break;
                 }
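Note: the hunks above cover only the host-side plumbing. `enable_scale` is a new push constant that is set to 1 when a GGML_OP_MUL_MAT_ID node is fused with a following GGML_OP_MUL, while the existing `enable_bias` flag continues to cover the GGML_OP_ADD_ID fusion; the shader-side change is not part of this excerpt. Below is a minimal sketch of the epilogue semantics the two flags select between. It is illustrative only: the function name, the scalar arguments, and their indexing are placeholders, not code from the repository.

#include <cstdint>

// Sketch: flag-gated epilogue applied to one accumulated mat-vec element.
// The host code above sets at most one of the two flags per fused node
// (see the if/else in ggml_vk_mul_mat_vec_id_q_f16), so the branches are
// exclusive in practice.
static inline float fused_epilogue(float acc, float bias_val, float scale_val,
                                   uint32_t enable_bias, uint32_t enable_scale) {
    if (enable_bias) {
        acc += bias_val;   // MUL_MAT_ID + ADD_ID fusion (pre-existing path)
    }
    if (enable_scale) {
        acc *= scale_val;  // MUL_MAT_ID + MUL fusion added by this change
    }
    return acc;
}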