@@ -3173,7 +3173,25 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
31733173 ggml_cuda_op_relu (ctx, dst);
31743174 break ;
31753175 case GGML_UNARY_OP_SIGMOID:
3176- ggml_cuda_op_sigmoid (ctx, dst);
3176+ if (i + 5 < cgraph->n_nodes &&
3177+ cgraph->nodes [i+1 ]->op == GGML_OP_RESHAPE &&
3178+ cgraph->nodes [i+2 ]->op == GGML_OP_ADD &&
3179+ cgraph->nodes [i+3 ]->op == GGML_OP_ARGSORT &&
3180+ cgraph->nodes [i+4 ]->op == GGML_OP_VIEW &&
3181+ cgraph->nodes [i+5 ]->op == GGML_OP_GET_ROWS) {
3182+ cuda_glm45moe_experts (ctx, cgraph->nodes [i+5 ], cgraph->nodes [i+4 ]);
3183+ i += 5 ;
3184+ }
3185+ else if (i + 4 < cgraph->n_nodes &&
3186+ cgraph->nodes [i+1 ]->op == GGML_OP_RESHAPE &&
3187+ cgraph->nodes [i+2 ]->op == GGML_OP_ADD &&
3188+ cgraph->nodes [i+3 ]->op == GGML_OP_GROUPED_TOPK &&
3189+ cgraph->nodes [i+4 ]->op == GGML_OP_GET_ROWS) {
3190+ cuda_bailingmoev2_experts (ctx, cgraph->nodes [i+4 ], cgraph->nodes [i+4 ]);
3191+ i += 4 ;
3192+ } else {
3193+ ggml_cuda_op_sigmoid (ctx, dst);
3194+ }
31773195 break ;
31783196 case GGML_UNARY_OP_HARDSIGMOID:
31793197 ggml_cuda_op_hardsigmoid (ctx, dst);
@@ -3315,10 +3333,28 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
33153333 ggml_cuda_op_pool2d (ctx, dst);
33163334 break ;
33173335 case GGML_OP_SUM_ROWS:
3318- ggml_cuda_op_sum_rows (ctx, dst);
3336+ if (i + 1 < cgraph->n_nodes &&
3337+ cgraph->nodes [i+1 ]->op == GGML_OP_DIV &&
3338+ cgraph->nodes [i+1 ]->src [1 ] == dst &&
3339+ cgraph->nodes [i+1 ]->src [0 ] == dst->src [0 ]) {
3340+ ggml_cuda_op_sum_rows_div (ctx, cgraph->nodes [i+1 ]);
3341+ ++i;
3342+ } else {
3343+ ggml_cuda_op_sum_rows (ctx, dst);
3344+ }
33193345 break ;
33203346 case GGML_OP_ARGSORT:
3321- ggml_cuda_op_argsort (ctx, dst);
3347+ if (i + 5 < cgraph->n_nodes &&
3348+ cgraph->nodes [i+1 ]->op == GGML_OP_VIEW &&
3349+ cgraph->nodes [i+2 ]->op == GGML_OP_GET_ROWS &&
3350+ cgraph->nodes [i+3 ]->op == GGML_OP_RESHAPE &&
3351+ cgraph->nodes [i+4 ]->op == GGML_OP_SOFT_MAX &&
3352+ cgraph->nodes [i+5 ]->op == GGML_OP_RESHAPE) {
3353+ cuda_openai_experts (ctx, dst, cgraph->nodes [i+4 ]);
3354+ i += 5 ;
3355+ } else {
3356+ ggml_cuda_op_argsort (ctx, dst);
3357+ }
33223358 break ;
33233359 case GGML_OP_ARGSORT_THRESH:
33243360 ggml_cuda_op_argsort_thresh (ctx, dst);
0 commit comments