@@ -3054,6 +3054,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
     }

+#define ENABLE_FUSION true
+
 #if IK_PRINT_TIMING
     int64_t tim1 = ggml_time_us();
 #endif
@@ -3084,7 +3086,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
-            if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_ADD &&
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
@@ -3098,7 +3100,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
                 i += 2;
             }
-            else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
@@ -3155,7 +3157,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     ggml_cuda_op_relu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-                    if (GGML_CUDA_FUSION && i + 5 < cgraph->n_nodes &&
+                    if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                         cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
@@ -3164,14 +3166,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                         cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
                         i += 5;
                     }
-                    else if (GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+                    else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                         cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
                         cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) {
                         cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
                         i += 4;
-                    } else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+                    } else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                         cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) {
                         ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
@@ -3242,23 +3244,23 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
         case GGML_OP_FUSED_RMS_NORM:
-            if (false && GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+            if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
                 cgraph->nodes[i+4]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+3], cgraph->nodes[i+4])) {
                 i += 4;
             }
-            else if (false && GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+            else if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
                 cgraph->nodes[i+2]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+3]->op == GGML_OP_FUSED_RMS_NORM &&
                 cgraph->nodes[i+4]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+1], cgraph->nodes[i+4])) {
                 i += 4;
             }
-            else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {
@@ -3310,7 +3312,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            if (GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
@@ -3333,20 +3335,20 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rope_back(ctx, dst);
             break;
         case GGML_OP_ROPE_FAST:
-            if (GGML_CUDA_FUSION && i + 3 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 3 < cgraph->n_nodes &&
                 (cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
                 (cgraph->nodes[i+2]->op == GGML_OP_RESHAPE || cgraph->nodes[i+2]->op == GGML_OP_VIEW) &&
                 cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3])) {
                 i += 3;
             }
-            else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
                 (cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
                 cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2])) {
                 i += 2;
             }
-            else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1])) {
                 i += 1;
@@ -3374,7 +3376,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
-            if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
                 cgraph->nodes[i+2]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[0] == dst &&
@@ -3383,7 +3385,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
                 i += 2;
             }
-            else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[1] == dst &&
                 cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) {
@@ -3394,7 +3396,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             }
             break;
         case GGML_OP_ARGSORT:
-            if (GGML_CUDA_FUSION && i + 5 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
                 cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
@@ -3430,6 +3432,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
 #endif

+#undef ENABLE_FUSION
+
     return true;
 }

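Note on the pattern: the diff replaces the build-wide GGML_CUDA_FUSION constant with a function-local ENABLE_FUSION macro that is #define'd at the top of the dispatch function and #undef'd before the return, so the toggle stays scoped to this one function. Each fusion branch is a peephole match over the compute graph: look ahead a fixed number of nodes, check their ops (plus contiguity and same-device constraints), launch one fused kernel, then advance i past the consumed nodes. Longer patterns are tried before their shorter prefixes, which is why the add+add+rms_norm branch precedes the add+rms_norm one. Below is a minimal, self-contained C++ sketch of that dispatch shape; Node, Graph, the OP_* values, and the run_* functions are hypothetical stand-ins, not ggml's real API:

// Hypothetical sketch of the peephole-fusion dispatch; not ggml's real types.
#include <cstdio>

#define ENABLE_FUSION true   // compile-time toggle, mirroring the diff

enum Op { OP_ADD, OP_RMS_NORM, OP_MUL };

struct Node  { Op op; };
struct Graph { Node * nodes[8]; int n_nodes; };

static void run_single(const Node * n)  { std::printf("run op %d\n", (int) n->op); }
static void run_fused_add_rms_norm()    { std::printf("run fused add+rms_norm kernel\n"); }

static void compute(const Graph * g) {
    for (int i = 0; i < g->n_nodes; ++i) {
        const Node * dst = g->nodes[i];
        // Peek ahead: if the next node completes a known pattern,
        // launch one fused kernel and skip the consumed node(s).
        if (ENABLE_FUSION && dst->op == OP_ADD &&
            i + 1 < g->n_nodes && g->nodes[i+1]->op == OP_RMS_NORM) {
            run_fused_add_rms_norm();
            i += 1;   // nodes[i+1] was handled by the fused kernel
        } else {
            run_single(dst);
        }
    }
}

int main() {
    Node a{OP_ADD}, r{OP_RMS_NORM}, m{OP_MUL};
    Graph g{{&a, &r, &m}, 3};
    compute(&g);   // prints the fused line, then "run op 2"
    return 0;
}

The same shape scales to the diff's longer matches (i + 5 lookahead, several op checks) by adding else-if branches in longest-first order.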