@@ -37,6 +37,7 @@ bool g_mul_mat_q = true;
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -3029,7 +3030,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3057,9 +3063,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+
+        return true;
     }
 
-    return true;
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
+        // Check for bias
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
 }
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
@@ -3080,10 +3109,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-        if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+        if (!disable_fusion) {
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                i += 2;
+                ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                continue;
+            }
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
0 commit comments