Skip to content

Commit 81853b5

Browse files
committed
Add ggml_check_edges
1 parent 6cccaef commit 81853b5

File tree

2 files changed

+79
-119
lines changed

2 files changed

+79
-119
lines changed

ggml/src/ggml-impl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
682682
#endif
683683

684684
#ifdef __cplusplus
685+
#include <array>
685686
#include <initializer_list>
686687
#include <vector>
687688

@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
697698
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
698699
}
699700

701+
// Return true if the edges in the graph match expectations.
702+
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
703+
int start_idx,
704+
std::initializer_list<std::array<int, 3>> edges) {
705+
for (const auto &edge : edges) {
706+
int dst_node = edge[0];
707+
int src_idx = edge[1];
708+
int src_node = edge[2];
709+
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
710+
return false;
711+
}
712+
}
713+
return true;
714+
}
715+
700716
// expose GGUF internals for test code
701717
GGML_API size_t gguf_type_size(enum gguf_type type);
702718
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 63 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,51 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
394394
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
395395
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
396396

397+
//node #963 ( SOFT_MAX): ffn_moe_probs-15 ( 64K) [Vulka ] use=2: ffn_moe_logits-15 ( 64K) [Vulka ]
398+
//node #964 ( RESHAPE): ffn_moe_probs-15 (re ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
399+
//node #965 ( ARGSORT): ffn_moe_argsort-15 ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
400+
//node #966 ( VIEW): ffn_moe_topk-15 ( 63K) [Vulka ] use=4: ffn_moe_argsort-15 ( 64K) [Vulka ]
401+
//node #967 ( GET_ROWS): ffn_moe_weights-15 ( 4K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 64K) [Vulka ] ffn_moe_topk-15 ( 63K) [Vulka ]
402+
//node #968 ( RESHAPE): ffn_moe_weights-15 ( ( 4K) [Vulka ] use=2: ffn_moe_weights-15 ( 4K) [Vulka ]
403+
//node #969 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ]
404+
//node #970 ( DIV): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ] ffn_moe_weights_sum- ( 0K) [Vulka ]
405+
//node #971 ( RESHAPE): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights_norm ( 4K) [Vulka ]
406+
// Expected wiring of the early-softmax + normalization top-k MoE pattern, in the form consumed
// by ggml_check_edges(). Each entry is { dst_node, src_idx, src_node }: the node at offset
// dst_node (relative to the first node of the pattern) must have src[src_idx] equal to the
// node at offset src_node. The trailing comment on each entry names the ops involved.
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
407+
{ 1, 0, 0 }, // reshape->src[0] == softmax
408+
{ 2, 0, 0 }, // argsort->src[0] == softmax
409+
{ 3, 0, 2 }, // view->src[0] == argsort
410+
{ 4, 0, 1 }, // get_rows->src[0] == reshape
411+
{ 4, 1, 3 }, // get_rows->src[1] == view
412+
{ 5, 0, 4 }, // reshape->src[0] == get_rows
413+
{ 6, 0, 5 }, // sum_rows->src[0] == reshape
414+
{ 7, 0, 5 }, // div->src[0] == reshape
415+
{ 7, 1, 6 }, // div->src[1] == sum_rows
416+
{ 8, 0, 7 }, // reshape->src[0] == div
417+
};
418+
419+
// same as early_softmax_norm but ending after the get_rows
420+
// Expected wiring of the early-softmax top-k MoE pattern (no normalization), in the form
// consumed by ggml_check_edges(). Each entry is { dst_node, src_idx, src_node } with node
// offsets relative to the first node of the pattern; only the first five nodes (offsets 0..4)
// are referenced.
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
421+
{ 1, 0, 0 }, // reshape->src[0] == softmax
422+
{ 2, 0, 0 }, // argsort->src[0] == softmax
423+
{ 3, 0, 2 }, // view->src[0] == argsort
424+
{ 4, 0, 1 }, // get_rows->src[0] == reshape
425+
{ 4, 1, 3 }, // get_rows->src[1] == view
426+
};
427+
428+
//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
429+
//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
430+
//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
431+
//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
432+
//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
433+
//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
434+
// Expected wiring of the late-softmax top-k MoE pattern, in the form consumed by
// ggml_check_edges(). Each entry is { dst_node, src_idx, src_node } with node offsets relative
// to the first node of the pattern (the argsort). get_rows->src[0] is not constrained by this
// list; only its src[1] (the view) is checked.
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
435+
{ 1, 0, 0 }, // view->src[0] == argsort
436+
{ 2, 1, 1 }, // get_rows->src[1] == view
437+
{ 3, 0, 2 }, // reshape->src[0] == get_rows
438+
{ 4, 0, 3 }, // soft_max->src[0] == reshape
439+
{ 5, 0, 4 }, // reshape->src[0] == soft_max
440+
};
441+
397442
enum topk_moe_mode {
398443
TOPK_MOE_EARLY_SOFTMAX,
399444
TOPK_MOE_EARLY_SOFTMAX_NORM,
@@ -12226,38 +12271,14 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1222612271

1222712272
switch (mode) {
1222812273
case TOPK_MOE_EARLY_SOFTMAX_NORM:
12229-
if (node_idx + (int)topk_moe_early_softmax_norm.size() > cgraph->n_nodes) {
12230-
return false;
12231-
}
12232-
for (size_t i = 0; i < topk_moe_early_softmax_norm.size(); ++i) {
12233-
if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax_norm.begin()[i]) {
12234-
return false;
12235-
}
12236-
}
1223712274
softmax = cgraph->nodes[node_idx + 0];
1223812275
weights = cgraph->nodes[node_idx + 8];
1223912276
break;
1224012277
case TOPK_MOE_EARLY_SOFTMAX:
12241-
if (node_idx + (int)topk_moe_early_softmax.size() > cgraph->n_nodes) {
12242-
return false;
12243-
}
12244-
for (size_t i = 0; i < topk_moe_early_softmax.size(); ++i) {
12245-
if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax.begin()[i]) {
12246-
return false;
12247-
}
12248-
}
1224912278
softmax = cgraph->nodes[node_idx + 0];
1225012279
weights = cgraph->nodes[node_idx + 4];
1225112280
break;
1225212281
case TOPK_MOE_LATE_SOFTMAX:
12253-
if (node_idx + (int)topk_moe_late_softmax.size() > cgraph->n_nodes) {
12254-
return false;
12255-
}
12256-
for (size_t i = 0; i < topk_moe_late_softmax.size(); ++i) {
12257-
if (cgraph->nodes[node_idx + i]->op != topk_moe_late_softmax.begin()[i]) {
12258-
return false;
12259-
}
12260-
}
1226112282
softmax = cgraph->nodes[node_idx + 4];
1226212283
weights = cgraph->nodes[node_idx + 5];
1226312284
break;
@@ -12289,95 +12310,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1228912310
return false;
1229012311
}
1229112312

12292-
// Check that the nodes don't have any unexpected uses
12293-
if (mode == TOPK_MOE_LATE_SOFTMAX) {
12294-
const ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
12295-
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
12296-
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
12297-
const ggml_tensor * reshape3 = cgraph->nodes[node_idx + 3];
12298-
// softmax is 4
12299-
const ggml_tensor * reshape5 = cgraph->nodes[node_idx + 5];
12300-
12301-
// argsort is used by view
12302-
if (ggml_node_get_use_count(cgraph, node_idx + 0) != 1 ||
12303-
view->src[0] != argsort) {
12304-
return false;
12305-
}
12306-
// view is written, we can skip checking it
12307-
12308-
// get_rows is used by reshape3
12309-
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
12310-
reshape3->src[0] != get_rows) {
12311-
return false;
12312-
}
12313-
12314-
// reshape3 is used by softmax
12315-
if (ggml_node_get_use_count(cgraph, node_idx + 3) != 1 ||
12316-
softmax->src[0] != reshape3) {
12317-
return false;
12318-
}
12319-
12320-
// softmax is used by reshape5
12321-
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
12322-
reshape5->src[0] != softmax) {
12323-
return false;
12324-
}
12325-
} else {
12326-
bool with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM;
12327-
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
12328-
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
12329-
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
12330-
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
12331-
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
12332-
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
12333-
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
12334-
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
12335-
12336-
// softmax is used by reshape and argsort
12337-
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
12338-
reshape1->src[0] != softmax ||
12339-
argsort->src[0] != softmax) {
12340-
return false;
12341-
}
12342-
// reshape is used by get_rows
12343-
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
12344-
get_rows->src[0] != reshape1) {
12345-
return false;
12346-
}
12347-
// argsort is used by view
12348-
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
12349-
view->src[0] != argsort) {
12350-
return false;
12351-
}
12352-
// view is written (via argsort), we can skip checking it
12353-
12354-
if (with_norm) {
12355-
// get_rows is used by reshape
12356-
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
12357-
reshape5->src[0] != get_rows) {
12358-
return false;
12359-
}
12360-
12361-
// reshape is used by sum_rows and div
12362-
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
12363-
sum_rows->src[0] != reshape5 ||
12364-
div->src[0] != reshape5) {
12365-
return false;
12366-
}
12367-
12368-
// sum_rows is used by div
12369-
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
12370-
div->src[1] != sum_rows) {
12371-
return false;
12372-
}
12373-
12374-
// div/reshape are written
12375-
if (reshape8->src[0] != div) {
12376-
return false;
12377-
}
12378-
}
12379-
}
12380-
1238112313
if (!ctx->device->subgroup_arithmetic ||
1238212314
!ctx->device->subgroup_shuffle ||
1238312315
!ctx->device->subgroup_require_full_support ||
@@ -12463,11 +12395,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1246312395
ctx->num_additional_fused_ops = num_adds - 1;
1246412396
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1246512397
ctx->num_additional_fused_ops = 1;
12466-
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12398+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12399+
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12400+
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1246712401
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12468-
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12402+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12403+
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12404+
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
1246912405
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12470-
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12406+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12407+
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12408+
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
1247112409
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
1247212410
}
1247312411
}
@@ -12566,11 +12504,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1256612504
ctx->num_additional_fused_ops = num_adds - 1;
1256712505
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1256812506
ctx->num_additional_fused_ops = 1;
12569-
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12507+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12508+
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12509+
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1257012510
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12571-
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12511+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12512+
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12513+
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
1257212514
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12573-
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12515+
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12516+
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12517+
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
1257412518
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
1257512519
}
1257612520
}

0 commit comments

Comments
 (0)