@@ -3094,12 +3094,28 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor

 }

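+// Helper for the fused paths below: returns true only when every node in [first, last]
+// and all of its sources live on the same CUDA device, so a fused kernel never has to
+// touch buffers that belong to another GPU.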
+static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, int last) {
+    if (last <= first) return true;
+    int device = ((const ggml_backend_cuda_buffer_context *)cgraph->nodes[first]->buffer->context)->device;
+    for (int i = first; i <= last; ++i) {
+        auto node = cgraph->nodes[i];
+        if (((const ggml_backend_cuda_buffer_context *)node->buffer->context)->device != device) return false;
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            if (!node->src[j] || !node->src[j]->buffer) continue;
+            if (((const ggml_backend_cuda_buffer_context *)node->src[j]->buffer->context)->device != device) return false;
+        }
+    }
+    return true;
+}
+
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
     }

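+// Simple toggle for the graph-level fusions added below; setting it to false falls
+// back to the unfused per-op kernels.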
+#define ENABLE_FUSION true
+
 #if IK_PRINT_TIMING
     int64_t tim1 = ggml_time_us();
 #endif
@@ -3129,17 +3145,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
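+            // Two chained contiguous adds feeding a fused RMS norm collapse into a
+            // single kernel, provided all three nodes sit on the same device.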
-            if (i + 1 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_ADD &&
+                cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
+                ggml_is_contiguous(dst->src[0]) &&
+                ggml_is_contiguous(dst->src[1]) &&
+                ggml_are_same_shape(dst->src[0], dst->src[1]) &&
+                dst == cgraph->nodes[i+1]->src[0] &&
+                ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) &&
+                ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) &&
+                cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] &&
+                ops_are_same_device(cgraph, i, i+2)) {
+                // printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
+                ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                i += 2;
+            }
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
-                ggml_are_same_shape(dst->src[0], dst->src[1])) {
+                ggml_are_same_shape(dst->src[0], dst->src[1]) &&
+                dst == cgraph->nodes[i+1]->src[0] && ops_are_same_device(cgraph, i, i+1)) {
                 ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]);
                 ++i;
             } else {
                 ggml_cuda_op_add(ctx, dst);
             }
-            // ggml_cuda_op_add(ctx, dst);
             break;
         case GGML_OP_ADD_ID:
             ggml_cuda_op_add_id(ctx, dst);
@@ -3183,22 +3214,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     ggml_cuda_op_relu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
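+                    // A sigmoid followed by reshape/add/argsort/view/get_rows (or by
+                    // reshape/add/grouped_topk/get_rows) is an expert-selection pattern
+                    // that gets dispatched to a dedicated fused kernel.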
-                    if (i + 5 < cgraph->n_nodes &&
+                    if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                         cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
                         cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
-                        cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
+                        cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+5)) {
                         cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
                         i += 5;
                     }
-                    else if (i + 4 < cgraph->n_nodes &&
+                    else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                         cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
-                        cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
+                        cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) {
                         cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
                         i += 4;
+                    } else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
+                        cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                        cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) {
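+                        // Plain sigmoid -> reshape -> add collapses into one biased-sigmoid launch.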
+                        ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
+                        i += 2;
                     } else {
                         ggml_cuda_op_sigmoid(ctx, dst);
                     }
@@ -3309,12 +3345,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
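+            // soft_max -> reshape -> argsort -> view -> get_rows is the top-k MoE
+            // routing pattern; run it as a single fused kernel when it applies.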
-            if (i + 4 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
-                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
+                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) &&
+                ops_are_same_device(cgraph, i, i+4)) {
                 ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
                 i += 4;
             } else {
@@ -3343,23 +3380,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
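+            // sum_rows followed by a div (optionally with a scale in between) over the
+            // same input collapses into a single sum_rows_div kernel.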
-            if (i + 1 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
+                cgraph->nodes[i+2]->op == GGML_OP_DIV &&
+                cgraph->nodes[i+1]->src[0] == dst &&
+                cgraph->nodes[i+2]->src[1] == cgraph->nodes[i+1] &&
+                cgraph->nodes[i+2]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+2)) {
+                ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
+                i += 2;
+            }
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[1] == dst &&
-                cgraph->nodes[i+1]->src[0] == dst->src[0]) {
+                cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) {
                 ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]);
                 ++i;
             } else {
                 ggml_cuda_op_sum_rows(ctx, dst);
             }
             break;
         case GGML_OP_ARGSORT:
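+            // argsort -> view -> get_rows -> reshape -> soft_max -> reshape is the
+            // expert-selection pattern handled by cuda_openai_experts.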
-            if (i + 5 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
                 cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
-                cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
+                cgraph->nodes[i+5]->op == GGML_OP_RESHAPE && ops_are_same_device(cgraph, i, i+4)) {
                 cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]);
                 i += 5;
             } else {
@@ -3390,6 +3436,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
 #endif

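+// Undefine the toggle so it does not leak beyond this function.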
+#undef ENABLE_FUSION
+
     return true;
 }
