@@ -3062,7 +3062,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
30623062
30633063 auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes [i+1 ] : nullptr ;
30643064
3065- // printf("%4d %s(%s)\n", i, ggml_op_name(dst->op), dst->name);
30663065 switch (dst->op ) {
30673066 case GGML_OP_ARGMAX:
30683067 ggml_cuda_argmax (ctx, dst);
@@ -3097,6 +3096,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
30973096 ggml_are_same_shape (dst, cgraph->nodes [i+1 ]->src [1 ]) &&
30983097 cgraph->nodes [i+1 ] == cgraph->nodes [i+2 ]->src [0 ] &&
30993098 ops_are_same_device (cgraph, i, i+2 )) {
3099+ // printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
31003100 ggml_cuda_op_fused_add_add_rms_norm (ctx, dst, cgraph->nodes [i+1 ], cgraph->nodes [i+2 ]);
31013101 i += 2 ;
31023102 }
@@ -3244,27 +3244,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
32443244 ggml_cuda_op_rms_norm (ctx, dst);
32453245 break ;
32463246 case GGML_OP_FUSED_RMS_NORM:
3247- // if (i + 6 < cgraph->n_nodes) {
3248- // printf("=== Fused rms_norm(%s)\n", dst->name);
3249- // for (int j = 1; j <= 6; ++j) printf(" %s(%s)\n", ggml_op_name(cgraph->nodes[i+j]->op), cgraph->nodes[i+j]->name);
3250- // }
3251- if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
3252- cgraph->nodes [i+1 ]->op == GGML_OP_VIEW &&
3253- cgraph->nodes [i+2 ]->op == GGML_OP_FUSED_RMS_NORM &&
3254- cgraph->nodes [i+3 ]->op == GGML_OP_ROPE_FAST &&
3255- cgraph->nodes [i+4 ]->op == GGML_OP_ROPE_FAST &&
3256- ggml_cuda_op_fused_rms_rope_fast (ctx, cgraph->nodes [i+3 ], cgraph->nodes [i+4 ])) {
3257- i += 4 ;
3258- }
3259- else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
3260- cgraph->nodes [i+1 ]->op == GGML_OP_ROPE_FAST &&
3261- cgraph->nodes [i+2 ]->op == GGML_OP_RESHAPE &&
3262- cgraph->nodes [i+3 ]->op == GGML_OP_FUSED_RMS_NORM &&
3263- cgraph->nodes [i+4 ]->op == GGML_OP_ROPE_FAST &&
3264- ggml_cuda_op_fused_rms_rope_fast (ctx, cgraph->nodes [i+1 ], cgraph->nodes [i+4 ])) {
3265- i += 4 ;
3266- }
3267- else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
3247+ if (i + 2 < cgraph->n_nodes &&
32683248 cgraph->nodes [i+1 ]->op == GGML_OP_VIEW &&
32693249 cgraph->nodes [i+2 ]->op == GGML_OP_FUSED_RMS_NORM &&
32703250 dst->ne [2 ] == 1 && cgraph->nodes [i+2 ]->ne [2 ] == 1 ) {
@@ -3338,32 +3318,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
33383318 case GGML_OP_ROPE_BACK:
33393319 ggml_cuda_op_rope_back (ctx, dst);
33403320 break ;
3341- case GGML_OP_ROPE_FAST:
3342- if (ENABLE_FUSION && i + 3 < cgraph->n_nodes &&
3343- (cgraph->nodes [i+1 ]->op == GGML_OP_RESHAPE || cgraph->nodes [i+1 ]->op == GGML_OP_VIEW) &&
3344- (cgraph->nodes [i+2 ]->op == GGML_OP_RESHAPE || cgraph->nodes [i+2 ]->op == GGML_OP_VIEW) &&
3345- cgraph->nodes [i+3 ]->op == GGML_OP_ROPE_FAST &&
3346- ggml_cuda_op_fused_rope_fast (ctx, dst, cgraph->nodes [i+3 ])) {
3347- i += 3 ;
3348- }
3349- else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
3350- (cgraph->nodes [i+1 ]->op == GGML_OP_RESHAPE || cgraph->nodes [i+1 ]->op == GGML_OP_VIEW) &&
3351- cgraph->nodes [i+2 ]->op == GGML_OP_ROPE_FAST &&
3352- ggml_cuda_op_fused_rope_fast (ctx, dst, cgraph->nodes [i+2 ])) {
3353- i += 2 ;
3354- }
3355- else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
3356- cgraph->nodes [i+1 ]->op == GGML_OP_ROPE_FAST &&
3357- ggml_cuda_op_fused_rope_fast (ctx, dst, cgraph->nodes [i+1 ])) {
3358- i += 1 ;
3359- }
3360- else {
3361- ggml_cuda_op_rope_fast (ctx, dst);
3362- }
3363- break ;
3364- case GGML_OP_ROPE_CACHE:
3365- ggml_cuda_op_rope_cache (ctx, dst);
3366- break ;
33673321 case GGML_OP_IM2COL:
33683322 ggml_cuda_op_im2col (ctx, dst);
33693323 break ;
@@ -4423,8 +4377,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
44234377 case GGML_OP_SOFT_CAP_MAX:
44244378 case GGML_OP_ROPE:
44254379 case GGML_OP_ROPE_BACK:
4426- case GGML_OP_ROPE_FAST:
4427- case GGML_OP_ROPE_CACHE:
44284380 return true ;
44294381 // case GGML_OP_ROPE:
44304382 // return ggml_is_contiguous(op->src[0]);
0 commit comments