@@ -394,6 +394,51 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
394394 GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
395395 GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
396396
397+ //node #963 ( SOFT_MAX): ffn_moe_probs-15 ( 64K) [Vulka ] use=2: ffn_moe_logits-15 ( 64K) [Vulka ]
398+ //node #964 ( RESHAPE): ffn_moe_probs-15 (re ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
399+ //node #965 ( ARGSORT): ffn_moe_argsort-15 ( 64K) [Vulka ] use=1: ffn_moe_probs-15 ( 64K) [Vulka ]
400+ //node #966 ( VIEW): ffn_moe_topk-15 ( 63K) [Vulka ] use=4: ffn_moe_argsort-15 ( 64K) [Vulka ]
401+ //node #967 ( GET_ROWS): ffn_moe_weights-15 ( 4K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 64K) [Vulka ] ffn_moe_topk-15 ( 63K) [Vulka ]
402+ //node #968 ( RESHAPE): ffn_moe_weights-15 ( ( 4K) [Vulka ] use=2: ffn_moe_weights-15 ( 4K) [Vulka ]
403+ //node #969 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ]
404+ //node #970 ( DIV): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 4K) [Vulka ] ffn_moe_weights_sum- ( 0K) [Vulka ]
405+ //node #971 ( RESHAPE): ffn_moe_weights_norm ( 4K) [Vulka ] use=1: ffn_moe_weights_norm ( 4K) [Vulka ]
406+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
407+ { 1, 0, 0 }, // reshape->src[0] == softmax
408+ { 2, 0, 0 }, // argsort->src[0] == softmax
409+ { 3, 0, 2 }, // view->src[0] == argsort
410+ { 4, 0, 1 }, // get_rows->src[0] == reshape
411+ { 4, 1, 3 }, // get_rows->src[1] == view
412+ { 5, 0, 4 }, // reshape->src[0] == get_rows
413+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
414+ { 7, 0, 5 }, // div->src[0] == reshape
415+ { 7, 1, 6 }, // div->src[1] == sum_rows
416+ { 8, 0, 7 }, // reshape->src[0] == div
417+ };
418+
419+ // same as early_softmax_norm but ending after the get_rows
420+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
421+ { 1, 0, 0 }, // reshape->src[0] == softmax
422+ { 2, 0, 0 }, // argsort->src[0] == softmax
423+ { 3, 0, 2 }, // view->src[0] == argsort
424+ { 4, 0, 1 }, // get_rows->src[0] == reshape
425+ { 4, 1, 3 }, // get_rows->src[1] == view
426+ };
427+
428+ //node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
429+ //node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
430+ //node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
431+ //node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
432+ //node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
433+ //node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
434+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
435+ { 1, 0, 0 }, // view->src[0] == argsort
436+ { 2, 1, 1 }, // get_rows->src[1] == view
437+ { 3, 0, 2 }, // reshape->src[0] == get_rows
438+ { 4, 0, 3 }, // soft_max->src[0] == reshape
439+ { 5, 0, 4 }, // reshape->src[0] == soft_max
440+ };
441+
397442enum topk_moe_mode {
398443 TOPK_MOE_EARLY_SOFTMAX,
399444 TOPK_MOE_EARLY_SOFTMAX_NORM,
@@ -12226,38 +12271,14 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1222612271
1222712272 switch (mode) {
1222812273 case TOPK_MOE_EARLY_SOFTMAX_NORM:
12229- if (node_idx + (int)topk_moe_early_softmax_norm.size() > cgraph->n_nodes) {
12230- return false;
12231- }
12232- for (size_t i = 0; i < topk_moe_early_softmax_norm.size(); ++i) {
12233- if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax_norm.begin()[i]) {
12234- return false;
12235- }
12236- }
1223712274 softmax = cgraph->nodes[node_idx + 0];
1223812275 weights = cgraph->nodes[node_idx + 8];
1223912276 break;
1224012277 case TOPK_MOE_EARLY_SOFTMAX:
12241- if (node_idx + (int)topk_moe_early_softmax.size() > cgraph->n_nodes) {
12242- return false;
12243- }
12244- for (size_t i = 0; i < topk_moe_early_softmax.size(); ++i) {
12245- if (cgraph->nodes[node_idx + i]->op != topk_moe_early_softmax.begin()[i]) {
12246- return false;
12247- }
12248- }
1224912278 softmax = cgraph->nodes[node_idx + 0];
1225012279 weights = cgraph->nodes[node_idx + 4];
1225112280 break;
1225212281 case TOPK_MOE_LATE_SOFTMAX:
12253- if (node_idx + (int)topk_moe_late_softmax.size() > cgraph->n_nodes) {
12254- return false;
12255- }
12256- for (size_t i = 0; i < topk_moe_late_softmax.size(); ++i) {
12257- if (cgraph->nodes[node_idx + i]->op != topk_moe_late_softmax.begin()[i]) {
12258- return false;
12259- }
12260- }
1226112282 softmax = cgraph->nodes[node_idx + 4];
1226212283 weights = cgraph->nodes[node_idx + 5];
1226312284 break;
@@ -12289,95 +12310,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1228912310 return false;
1229012311 }
1229112312
12292- // Check that the nodes don't have any unexpected uses
12293- if (mode == TOPK_MOE_LATE_SOFTMAX) {
12294- const ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
12295- const ggml_tensor * view = cgraph->nodes[node_idx + 1];
12296- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
12297- const ggml_tensor * reshape3 = cgraph->nodes[node_idx + 3];
12298- // softmax is 4
12299- const ggml_tensor * reshape5 = cgraph->nodes[node_idx + 5];
12300-
12301- // argsort is used by view
12302- if (ggml_node_get_use_count(cgraph, node_idx + 0) != 1 ||
12303- view->src[0] != argsort) {
12304- return false;
12305- }
12306- // view is written, we can skip checking it
12307-
12308- // get_rows is used by reshape3
12309- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
12310- reshape3->src[0] != get_rows) {
12311- return false;
12312- }
12313-
12314- // reshape3 is used by softmax
12315- if (ggml_node_get_use_count(cgraph, node_idx + 3) != 1 ||
12316- softmax->src[0] != reshape3) {
12317- return false;
12318- }
12319-
12320- // softmax is used by reshape5
12321- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
12322- reshape5->src[0] != softmax) {
12323- return false;
12324- }
12325- } else {
12326- bool with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM;
12327- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
12328- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
12329- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
12330- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
12331- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
12332- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
12333- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
12334- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
12335-
12336- // softmax is used by reshape and argsort
12337- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
12338- reshape1->src[0] != softmax ||
12339- argsort->src[0] != softmax) {
12340- return false;
12341- }
12342- // reshape is used by get_rows
12343- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
12344- get_rows->src[0] != reshape1) {
12345- return false;
12346- }
12347- // argsort is used by view
12348- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
12349- view->src[0] != argsort) {
12350- return false;
12351- }
12352- // view is written (via argsort), we can skip checking it
12353-
12354- if (with_norm) {
12355- // get_rows is used by reshape
12356- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
12357- reshape5->src[0] != get_rows) {
12358- return false;
12359- }
12360-
12361- // reshape is used by sum_rows and div
12362- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
12363- sum_rows->src[0] != reshape5 ||
12364- div->src[0] != reshape5) {
12365- return false;
12366- }
12367-
12368- // sum_rows is used by div
12369- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
12370- div->src[1] != sum_rows) {
12371- return false;
12372- }
12373-
12374- // div/reshape are written
12375- if (reshape8->src[0] != div) {
12376- return false;
12377- }
12378- }
12379- }
12380-
1238112313 if (!ctx->device->subgroup_arithmetic ||
1238212314 !ctx->device->subgroup_shuffle ||
1238312315 !ctx->device->subgroup_require_full_support ||
@@ -12463,11 +12395,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1246312395 ctx->num_additional_fused_ops = num_adds - 1;
1246412396 } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1246512397 ctx->num_additional_fused_ops = 1;
12466- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12398+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12399+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12400+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1246712401 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12468- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12402+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12403+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12404+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
1246912405 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12470- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12406+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12407+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12408+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
1247112409 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
1247212410 }
1247312411 }
@@ -12566,11 +12504,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1256612504 ctx->num_additional_fused_ops = num_adds - 1;
1256712505 } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1256812506 ctx->num_additional_fused_ops = 1;
12569- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
12507+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 8 }) &&
12508+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
12509+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
1257012510 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
12571- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
12511+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
12512+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
12513+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
1257212514 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
12573- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
12515+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
12516+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
12517+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
1257412518 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
1257512519 }
1257612520 }
0 commit comments