@@ -32,6 +32,7 @@
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2766,34 +2767,59 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
-        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
-        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+    switch (ops.size()) {
+        case 2:
+            if (ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+                const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+                const ggml_tensor *mul = cgraph->nodes[node_idx+1];
 
-        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
-        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+                GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+                GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
 
-        // rms norm only supports F32
-        if (mul->src[0]->type != GGML_TYPE_F32 ||
-            mul->src[1]->type != GGML_TYPE_F32 ||
-            mul->type != GGML_TYPE_F32) {
-            return false;
-        }
+                // rms norm only supports F32
+                if (mul->src[0]->type != GGML_TYPE_F32 ||
+                    mul->src[1]->type != GGML_TYPE_F32 ||
+                    mul->type != GGML_TYPE_F32) {
+                    return false;
+                }
 
-        // if rms norm is the B operand, then we don't handle broadcast
-        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
-            return false;
-        }
+                // if rms norm is the B operand, then we don't handle broadcast
+                if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+                    return false;
+                }
+
+                // rms_norm kernel assumes contiguous rows
+                if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+                    return false;
+                }
+            }
+            break;
+        case 3:
+            if (ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+                && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+                const ggml_tensor *scale = cgraph->nodes[node_idx];
+                const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
+                const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+                GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+                GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+                if (tanh->src[0] != scale || scale2->src[0] != tanh) {
+                    return false;
+                }
 
-        // rms_norm kernel assumes contiguous rows
-        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+                if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+                    return false;
+                }
+            }
+            break;
+        default:
             return false;
-        }
     }
 
     return true;
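
The three-op pattern matched in `case 3` above is the softcap used e.g. by Gemma-2 for its logits, `y = cap * tanh(x / cap)`, which a ggml graph expresses as SCALE(1/cap) → UNARY(TANH) → SCALE(cap). As a minimal sketch of how such a subgraph is built (the `build_softcap` helper is hypothetical, not a ggml API), note that `ggml_scale` leaves the bias op param at 0, which is exactly what the `ggml_get_op_params_f32(..., 1) != 0.0f` guards require:

```cpp
#include "ggml.h"

// Hypothetical helper showing the subgraph the matcher targets:
// SCALE(1/cap) -> UNARY(TANH) -> SCALE(cap), i.e. y = cap * tanh(x / cap).
static ggml_tensor * build_softcap(ggml_context * ctx, ggml_tensor * x, float cap) {
    ggml_tensor * t = ggml_scale(ctx, x, 1.0f / cap); // GGML_OP_SCALE, bias op param stays 0
    t = ggml_tanh(ctx, t);                            // GGML_OP_UNARY with GGML_UNARY_OP_TANH
    return ggml_scale(ctx, t, cap);                   // GGML_OP_SCALE, bias op param stays 0
}
```
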
@@ -2817,10 +2843,27 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-        if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+        if (!disable_fusion) {
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                ggml_tensor * src0 = node->src[0];
+                float scale = ggml_get_op_params_f32(node, 0);
+
+                i += 2; node = cgraph->nodes[i];
+                float softcap = ggml_get_op_params_f32(node, 0);
+
+                ggml_set_op_params_f32(node, 0, scale);
+                ggml_set_op_params_f32(node, 1, softcap);
+                node->src[0] = src0;
+
+                ggml_cuda_op_softcap(*cuda_ctx, node);
+                continue;
+            }
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
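
After this rewrite, the surviving second SCALE node carries both constants (op param 0 = the first scale factor, op param 1 = the softcap factor) and its `src[0]` points at the original input, so a single kernel can evaluate the whole scale → tanh → scale chain. `ggml_cuda_op_softcap` itself lives in the new softcap files (only `softcap.cuh` appears in this excerpt); as a rough sketch under that assumption, a fused elementwise kernel would look like:

```cuda
// Rough sketch of a fused softcap kernel (illustration only, not the PR's
// actual implementation): dst[i] = tanhf(x[i] * scale) * softcap, using the
// two op params packed into the surviving SCALE node above.
static __global__ void softcap_f32(const float * x, float * dst,
                                   const float scale, const float softcap, const int k) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = tanhf(x[i] * scale) * softcap;
}
```
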