Skip to content

Commit 80a6cf6

Browse files
authored
vulkan: fuse mul_mat_id + mul (#17095)
* vulkan: fuse mul_mat_id + mul. This comes up in qwen3 moe. * split mul_mat_id fusion tests into a separate class.
1 parent 0750a59 commit 80a6cf6

File tree

3 files changed

+180
-51
lines changed

3 files changed

+180
-51
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,7 @@ struct vk_mat_vec_push_constants {
830830
uint32_t batch_stride_b;
831831
uint32_t batch_stride_d;
832832
uint32_t enable_bias;
833+
uint32_t enable_scale;
833834
uint32_t ne02;
834835
uint32_t ne12;
835836
uint32_t broadcast2;
@@ -852,6 +853,7 @@ struct vk_mat_vec_id_push_constants {
852853
uint32_t batch_stride_b;
853854
uint32_t batch_stride_d;
854855
uint32_t enable_bias;
856+
uint32_t enable_scale;
855857
uint32_t nei0;
856858
uint32_t ne11;
857859
};
@@ -6863,7 +6865,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
68636865
// compute
68646866
const vk_mat_vec_push_constants pc = {
68656867
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
6866-
stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
6868+
stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
68676869
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
68686870
};
68696871
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -7684,13 +7686,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
76847686
groups_x = CEIL_DIV(groups_x, groups_z);
76857687
}
76867688

7687-
uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
7689+
uint32_t enable_bias = 0;
7690+
uint32_t enable_scale = 0;
7691+
if (ctx->num_additional_fused_ops > 0) {
7692+
if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
7693+
enable_scale = 1;
7694+
} else {
7695+
GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
7696+
enable_bias = 1;
7697+
}
7698+
}
76887699

76897700
vk_buffer d_B = d_D;
76907701
size_t b_buf_offset = 0;
76917702
uint64_t b_sz = 0;
76927703

7693-
if (enable_bias) {
7704+
if (enable_bias || enable_scale) {
76947705
const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
76957706

76967707
bool b_uma = false;
@@ -7712,7 +7723,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
77127723
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
77137724
(uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
77147725

7715-
enable_bias,
7726+
enable_bias, enable_scale,
77167727

77177728
(uint32_t)nei0, (uint32_t)ne11,
77187729
};
@@ -12490,6 +12501,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
1249012501
}
1249112502
}
1249212503

12504+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
12505+
// additional constraints specific to this fusion
12506+
const ggml_tensor *mmid = cgraph->nodes[node_idx];
12507+
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
12508+
const ggml_tensor *scale = mul->src[1];
12509+
12510+
if (mmid != mul->src[0]) {
12511+
return false;
12512+
}
12513+
// mat-vec only
12514+
if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
12515+
return false;
12516+
}
12517+
// shaders assume the types match
12518+
if (mmid->type != scale->type) {
12519+
return false;
12520+
}
12521+
// shaders assume the bias is contiguous
12522+
if (!ggml_is_contiguous(scale)) {
12523+
return false;
12524+
}
12525+
// unaligned bias isn't handled
12526+
if (get_misalign_bytes(ctx, scale) != 0) {
12527+
return false;
12528+
}
12529+
// shader only indexes by expert index
12530+
if (scale->ne[0] != 1 ||
12531+
scale->ne[1] != mul->ne[1] ||
12532+
scale->ne[2] != 1 ||
12533+
scale->ne[3] != 1) {
12534+
return false;
12535+
}
12536+
}
12537+
1249312538
return true;
1249412539
}
1249512540

@@ -12798,6 +12843,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1279812843
ctx->num_additional_fused_ops = 1;
1279912844
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
1280012845
ctx->num_additional_fused_ops = 1;
12846+
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
12847+
ctx->num_additional_fused_ops = 1;
1280112848
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
1280212849
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
1280312850
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
@@ -13033,7 +13080,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1303313080
is_src_of(graph->nodes[j], graph->nodes[c]) &&
1303413081
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
1303513082
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
13036-
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
13083+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
13084+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
1303713085
ok = false;
1303813086
break;
1303913087
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ layout (push_constant) uniform parameter
4949
uint batch_stride_d;
5050

5151
uint enable_bias;
52+
uint enable_scale;
5253

5354
#ifdef MUL_MAT_ID
5455
uint nei0;
@@ -129,6 +130,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
129130
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
130131
#endif
131132
}
133+
#ifdef MUL_MAT_ID
134+
if (p.enable_scale != 0) {
135+
const uint expert_idx = gl_GlobalInvocationID.y;
136+
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
137+
}
138+
#endif
132139
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
133140
}
134141
}
@@ -171,6 +178,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
171178
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
172179
#endif
173180
}
181+
#ifdef MUL_MAT_ID
182+
if (p.enable_scale != 0) {
183+
const uint expert_idx = gl_GlobalInvocationID.y;
184+
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
185+
}
186+
#endif
174187
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
175188
}
176189
}
@@ -203,6 +216,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
203216
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
204217
#endif
205218
}
219+
#ifdef MUL_MAT_ID
220+
if (p.enable_scale != 0) {
221+
const uint expert_idx = gl_GlobalInvocationID.y;
222+
tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
223+
}
224+
#endif
206225
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
207226
}
208227
}

0 commit comments

Comments (0)