Skip to content

Commit b046c73

Browse files
committed
handle clamp added in #16655
1 parent 180eef4 commit b046c73

File tree

3 files changed

+38
-19
lines changed

3 files changed

+38
-19
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -387,22 +387,24 @@ static constexpr uint32_t num_topk_moe_pipelines = 10;
387387

388388
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
389389
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
390-
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
390+
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
391+
GGML_OP_RESHAPE };
391392
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
392393
GGML_OP_VIEW, GGML_OP_GET_ROWS };
393394
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
394395
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
395396
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
396397

397-
//node #963 ( SOFT_MAX): ffn_moe_probs-15 ( 64K) [Vulka ] use=2: ffn_moe_logits-15 ( 64K) [Vulka ]
398-
//node #964 ( RESHAPE): ffn_moe_probs-15 (re ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
399-
//node #965 ( ARGSORT): ffn_moe_argsort-15 ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
400-
//node #966 ( VIEW): ffn_moe_topk-15 ( 63K) [Vulka ] use=4: ffn_moe_argsort-15 ( 64K) [Vulka ]
401-
//node #967 ( GET_ROWS): ffn_moe_weights-15 ( 4K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 64K) [Vulka ] ffn_moe_topk-15 ( 63K) [Vulka ]
402-
//node #968 ( RESHAPE): ffn_moe_weights-15 ( ( 4K) [Vulka ] use=2: ffn_moe_weights-15 ( 4K) [Vulka ]
403-
//node #969 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ]
404-
//node #970 ( DIV): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ] ffn_moe_weights_sum- ( 0K) [Vulka ]
405-
//node #971 ( RESHAPE): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights_norm ( 4K) [Vulka ]
398+
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
399+
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
400+
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
401+
//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
402+
//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
403+
//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
404+
//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
405+
//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
406+
//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
407+
//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
406408
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
407409
{ 1, 0, 0 }, // reshape->src[0] == softmax
408410
{ 2, 0, 0 }, // argsort->src[0] == softmax
@@ -411,9 +413,10 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
411413
{ 4, 1, 3 }, // get_rows->src[1] == view
412414
{ 5, 0, 4 }, // reshape->src[0] == get_rows
413415
{ 6, 0, 5 }, // sum_rows->src[0] == reshape
414-
{ 7, 0, 5 }, // div->src[0] == reshape
415-
{ 7, 1, 6 }, // div->src[1] == sum_rows
416-
{ 8, 0, 7 }, // reshape->src[0] == div
416+
{ 7, 0, 6 }, // clamp->src[0] == sum_rows
417+
{ 8, 0, 5 }, // div->src[0] == reshape
418+
{ 8, 1, 7 }, // div->src[1] == clamp
419+
{ 9, 0, 8 }, // reshape->src[0] == div
417420
};
418421

419422
// same as early_softmax_norm but ending after the get_rows
@@ -1013,6 +1016,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
10131016
struct vk_op_topk_moe_push_constants {
10141017
uint32_t n_rows;
10151018
uint32_t n_expert_used;
1019+
float clamp_min;
1020+
float clamp_max;
10161021
};
10171022

10181023
struct vk_op_add_id_push_constants {
@@ -9632,7 +9637,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
96329637

96339638
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
96349639
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
9635-
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 8] :
9640+
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
96369641
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
96379642
cgraph->nodes[node_idx + 5];
96389643
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
@@ -9694,9 +9699,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
96949699
GGML_ASSERT(d_ids != nullptr);
96959700
}
96969701

9697-
vk_op_topk_moe_push_constants pc;
9702+
vk_op_topk_moe_push_constants pc {};
96989703
pc.n_rows = n_rows;
96999704
pc.n_expert_used = n_expert_used;
9705+
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
9706+
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
9707+
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
9708+
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
9709+
}
97009710

97019711
GGML_ASSERT(n_expert_used <= n_experts);
97029712

@@ -12290,7 +12300,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1229012300
switch (mode) {
1229112301
case TOPK_MOE_EARLY_SOFTMAX_NORM:
1229212302
softmax = cgraph->nodes[node_idx + 0];
12293-
weights = cgraph->nodes[node_idx + 8];
12303+
weights = cgraph->nodes[node_idx + 9];
1229412304
break;
1229512305
case TOPK_MOE_EARLY_SOFTMAX:
1229612306
softmax = cgraph->nodes[node_idx + 0];
@@ -12413,7 +12423,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1241312423
ctx->num_additional_fused_ops = num_adds - 1;
1241412424
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1241512425
ctx->num_additional_fused_ops = 1;
12416-
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12426+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
1241712427
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
1241812428
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1241912429
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
@@ -12522,7 +12532,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1252212532
ctx->num_additional_fused_ops = num_adds - 1;
1252312533
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1252412534
ctx->num_additional_fused_ops = 1;
12525-
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12535+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
1252612536
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
1252712537
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1252812538
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
@@ -12696,6 +12706,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1269612706
if (keep_pattern(topk_moe_early_softmax_norm)) {
1269712707
continue;
1269812708
}
12709+
if (keep_pattern(topk_moe_early_softmax)) {
12710+
continue;
12711+
}
1269912712
if (keep_pattern(topk_moe_late_softmax)) {
1270012713
continue;
1270112714
}
@@ -12718,7 +12731,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1271812731
continue;
1271912732
}
1272012733
// Don't pull forward nodes from fusion patterns
12721-
if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_late_softmax, j)) {
12734+
if (match_pattern(topk_moe_early_softmax_norm, j) ||
12735+
match_pattern(topk_moe_early_softmax, j) ||
12736+
match_pattern(topk_moe_late_softmax, j)) {
1272212737
continue;
1272312738
}
1272412739
bool ok = true;

ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
1111
{
1212
uint n_rows;
1313
uint n_expert_used;
14+
float clamp_min;
15+
float clamp_max;
1416
};
1517

1618
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -146,6 +148,7 @@ void main() {
146148

147149
if (with_norm) {
148150
wt_sum = subgroupAdd(wt_sum);
151+
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
149152
const float inv_sum = 1.0f / wt_sum;
150153

151154
[[unroll]]

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4712,6 +4712,7 @@ struct test_topk_moe: public test_case {
47124712
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
47134713
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
47144714

4715+
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
47154716
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
47164717
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
47174718
}

0 commit comments

Comments
 (0)