 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2770,7 +2771,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif

-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2798,9 +2804,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+
+        return true;
     }

-    return true;
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
+        // both scale ops must have a zero bias for the fused softcap path
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
 }

 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
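Note on the new branch: once both biases are checked to be zero, the SCALE -> UNARY(TANH) -> SCALE chain collapses to a single elementwise softcap. A minimal reference sketch of the semantics the fused ggml_cuda_op_softcap path should reproduce (illustrative host code, not the actual CUDA kernel; softcap_ref is a made-up name):

#include <math.h>
#include <stdint.h>

// y[i] = s1 * tanh(s0 * x[i]), where s0/s1 are the scale params of the first/second GGML_OP_SCALE node
static void softcap_ref(const float * x, float * y, int64_t n, float s0, float s1) {
    for (int64_t i = 0; i < n; ++i) {
        y[i] = s1 * tanhf(s0 * x[i]);
    }
}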
@@ -2821,10 +2850,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }

         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-        if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+        if (!disable_fusion) {
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                i += 2;
+                ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                continue;
+            }
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
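For context, a sketch of how the pattern this fusion targets is typically built on the graph side (the helper name and the 1/cap convention are assumptions for illustration; the fusion itself only requires the three-node scale/tanh/scale shape with zero biases):

#include "ggml.h"

// hypothetical helper: softcapping x against a cap value produces exactly the
// SCALE(1/cap) -> UNARY(TANH) -> SCALE(cap) chain matched by ggml_cuda_can_fuse above
static struct ggml_tensor * build_softcap(struct ggml_context * ctx, struct ggml_tensor * x, float cap) {
    struct ggml_tensor * t = ggml_scale(ctx, x, 1.0f / cap);
    t = ggml_tanh(ctx, t);
    return ggml_scale(ctx, t, cap);
}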