11#include " moe-expert-reduce.cuh"
22
3- // This kernel is fusion of the expert weight reduce, common in MoE models
3+ // This kernel is a fusion of the expert weight reduce, common in MoE models
44
55template <int n_expert_used_template>
66__global__ void moe_expert_reduce_cuda (const float * __restrict__ experts,
77 const float * __restrict__ weights,
88 float * __restrict__ dst,
99 const int n_expert_used,
1010 const int n_cols) {
11- const int row = blockIdx .x ;
12- const int n_expert_used_t = n_expert_used_template == 0 ? n_expert_used : n_expert_used_template;
13-
11+ const int row = blockIdx .x ;
1412 const int col = blockIdx .y * blockDim .x + threadIdx .x ;
1513 if (col >= n_cols) {
1614 return ;
@@ -22,7 +20,7 @@ __global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
2220
2321 float acc = 0 .f ;
2422 if constexpr (n_expert_used_template == 0 ) {
25- for (int expert = 0 ; expert < n_expert_used_t ; ++expert) {
23+ for (int expert = 0 ; expert < n_expert_used ; ++expert) {
2624 ggml_cuda_mad (acc, experts[col], weights[expert]);
2725 experts += n_cols;
2826 }
@@ -98,37 +96,34 @@ static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
9896}
9997
10098bool ggml_cuda_should_use_moe_expert_reduce (const ggml_cgraph * cgraph, int start_index, int end_index) {
101- const ggml_tensor * experts = cgraph->nodes [start_index];
102- if (experts->op != GGML_OP_MUL) {
99+ const ggml_tensor * mul = cgraph->nodes [start_index];
100+
101+ if (mul->op != GGML_OP_MUL || !ggml_is_contiguous (mul->src [0 ]) || !ggml_is_contiguous (mul->src [1 ])) {
103102 return false ;
104103 }
105104
106105 int current_node = start_index + 1 ;
107106 size_t current_offset = 0 ;
108107
109- const ggml_tensor * view_nodes[32 ];
110- int num_views = 0 ;
108+ std::vector<const ggml_tensor *> view_nodes;
111109 // check if all are views of the expert in increasing order
112110 while (current_node < end_index && cgraph->nodes [current_node]->op == GGML_OP_VIEW) {
113111 const ggml_tensor * node = cgraph->nodes [current_node];
114- if (node->view_src != experts ) {
112+ if (node->view_src != mul ) {
115113 return false ;
116114 }
117115 if (node->view_offs < current_offset) {
118116 return false ;
119117 }
120118 current_offset = node->view_offs ;
121119 current_node++;
122- view_nodes[num_views++] = node;
123-
124- if (num_views >= 32 ) {
125- return false ;
126- }
120+ view_nodes.push_back (node);
127121 }
128122
129123 // check if all the adds are in increasing order
130- const ggml_tensor * prev_add_src = view_nodes[0 ];
124+ const ggml_tensor * prev_add_src = view_nodes. size () ? view_nodes [0 ] : nullptr ;
131125 int num_adds = 0 ;
126+ int num_views = view_nodes.size ();
132127 while (current_node < end_index && cgraph->nodes [current_node]->op == GGML_OP_ADD) {
133128 const ggml_tensor * add_node = cgraph->nodes [current_node];
134129
0 commit comments