@@ -2772,57 +2772,54 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
27722772 return false ;
27732773 }
27742774
2775- switch (ops.size ()) {
2776- case 2 :
2777- if (ops.begin ()[0 ] == GGML_OP_RMS_NORM && ops.begin ()[1 ] == GGML_OP_MUL) {
2778- const ggml_tensor *rms_norm = cgraph->nodes [node_idx];
2779- const ggml_tensor *mul = cgraph->nodes [node_idx+1 ];
2780-
2781- GGML_ASSERT (rms_norm->src [0 ]->type == GGML_TYPE_F32);
2782- GGML_ASSERT (rms_norm->type == GGML_TYPE_F32);
2783-
2784- // rms norm only supports F32
2785- if (mul->src [0 ]->type != GGML_TYPE_F32 ||
2786- mul->src [1 ]->type != GGML_TYPE_F32 ||
2787- mul->type != GGML_TYPE_F32) {
2788- return false ;
2789- }
2775+ if (ops.size () == 2 && ops.begin ()[0 ] == GGML_OP_RMS_NORM && ops.begin ()[1 ] == GGML_OP_MUL) {
2776+ const ggml_tensor *rms_norm = cgraph->nodes [node_idx];
2777+ const ggml_tensor *mul = cgraph->nodes [node_idx+1 ];
27902778
2791- // if rms norm is the B operand, then we don't handle broadcast
2792- if (rms_norm == mul->src [1 ] && !ggml_are_same_shape (mul->src [0 ], rms_norm->src [1 ])) {
2793- return false ;
2794- }
2779+ GGML_ASSERT (rms_norm->src [0 ]->type == GGML_TYPE_F32);
2780+ GGML_ASSERT (rms_norm->type == GGML_TYPE_F32);
27952781
2796-              // rms_norm kernel assumes contiguous rows
2797- if (!ggml_is_contiguous_rows (mul->src [0 ]) || !ggml_is_contiguous_rows (mul->src [1 ])) {
2798- return false ;
2799- }
2800- }
2801- break ;
2802- case 3 :
2803- if (ops.begin ()[0 ] == GGML_OP_SCALE && ops.begin ()[1 ] == GGML_OP_UNARY && ops.begin ()[2 ] == GGML_OP_SCALE
2804- && unary_ops.size () == 1 && unary_ops.begin ()[0 ] == GGML_UNARY_OP_TANH) {
2805- const ggml_tensor *scale = cgraph->nodes [node_idx];
2806- const ggml_tensor *tanh = cgraph->nodes [node_idx+1 ];
2807- const ggml_tensor *scale2 = cgraph->nodes [node_idx+2 ];
2782+ // rms norm only supports F32
2783+ if (mul->src [0 ]->type != GGML_TYPE_F32 ||
2784+ mul->src [1 ]->type != GGML_TYPE_F32 ||
2785+ mul->type != GGML_TYPE_F32) {
2786+ return false ;
2787+ }
28082788
2809- GGML_ASSERT (scale->src [0 ]->type == GGML_TYPE_F32);
2810- GGML_ASSERT (scale->type == GGML_TYPE_F32);
2789+ // if rms norm is the B operand, then we don't handle broadcast
2790+ if (rms_norm == mul->src [1 ] && !ggml_are_same_shape (mul->src [0 ], rms_norm->src [1 ])) {
2791+ return false ;
2792+ }
28112793
2812- if (tanh->src [0 ] != scale || scale2->src [0 ] != tanh) {
2813- return false ;
2814- }
2794+          // rms_norm kernel assumes contiguous rows
2795+ if (!ggml_is_contiguous_rows (mul->src [0 ]) || !ggml_is_contiguous_rows (mul->src [1 ])) {
2796+ return false ;
2797+ }
28152798
2816-              if (ggml_get_op_params_f32 (scale, 1 ) != 0.0f || ggml_get_op_params_f32 (scale2, 1 ) != 0.0f ) {
2817- return false ;
2818- }
2819- }
2820- break ;
2821- default :
2799+ return true ;
2800+ }
2801+
2802+ if (ops.size () == 3 && ops.begin ()[0 ] == GGML_OP_SCALE && ops.begin ()[1 ] == GGML_OP_UNARY && ops.begin ()[2 ] == GGML_OP_SCALE
2803+ && unary_ops.size () == 1 && unary_ops.begin ()[0 ] == GGML_UNARY_OP_TANH) {
2804+ const ggml_tensor *scale = cgraph->nodes [node_idx];
2805+ const ggml_tensor *tanh = cgraph->nodes [node_idx+1 ];
2806+ const ggml_tensor *scale2 = cgraph->nodes [node_idx+2 ];
2807+
2808+ GGML_ASSERT (scale->src [0 ]->type == GGML_TYPE_F32);
2809+ GGML_ASSERT (scale->type == GGML_TYPE_F32);
2810+
2811+ if (tanh->src [0 ] != scale || scale2->src [0 ] != tanh) {
28222812 return false ;
2813+ }
2814+
2815+          if (ggml_get_op_params_f32 (scale, 1 ) != 0.0f || ggml_get_op_params_f32 (scale2, 1 ) != 0.0f ) {
2816+ return false ;
2817+ }
2818+
2819+ return true ;
28232820 }
28242821
2825- return true ;
2822+ return false ;
28262823}
28272824
28282825static void evaluate_and_capture_cuda_graph (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
0 commit comments