@@ -35,7 +35,6 @@ bool g_mul_mat_q = true;
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
-#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2776,12 +2775,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
-#ifndef NDEBUG
-    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
-    GGML_ASSERT(unary_ops.size() == num_unary);
-#endif
-
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2809,32 +2803,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
-
-        return true;
-    }
-
-    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
-        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
-        const ggml_tensor *scale = cgraph->nodes[node_idx];
-        const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
-        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
-
-        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
-        GGML_ASSERT(scale->type == GGML_TYPE_F32);
-
-        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
-            return false;
-        }
-
-        // Check for bias
-        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
-            return false;
-        }
-
-        return true;
     }
 
-    return false;
+    return true;
 }
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
@@ -2855,18 +2826,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             }
 
             static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-            if (!disable_fusion) {
-                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
-                    ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-                    i++;
-                    continue;
-                }
-
-                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
-                    i += 2;
-                    ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
-                    continue;
-                }
+            if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
             }
 #ifndef NDEBUG
             assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
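
For reference, the fused pattern removed in the hunks above corresponds to logit softcapping: a SCALE, a TANH unary, and a second SCALE, with both SCALE bias parameters required to be zero. Below is a minimal CPU-side sketch of the equivalent element-wise math, not the CUDA kernel that was removed; the helper name softcap_reference is hypothetical, and s0/s1 stand for the scale factors of the first and second SCALE op.

#include <math.h>

// Reference for the removed fusion: y = s1 * tanhf(s0 * x), element-wise.
// s0 and s1 are the scale factors of the first and second GGML_OP_SCALE;
// both bias parameters are assumed to be zero, which is the condition the
// removed ggml_cuda_can_fuse branch checked before fusing.
static void softcap_reference(const float * x, float * y, int n, float s0, float s1) {
    for (int i = 0; i < n; ++i) {
        y[i] = s1 * tanhf(s0 * x[i]);
    }
}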