33
33
#include " ggml-cuda/rope.cuh"
34
34
#include " ggml-cuda/roll.cuh"
35
35
#include " ggml-cuda/scale.cuh"
36
+ #include " ggml-cuda/softcap.cuh"
36
37
#include " ggml-cuda/softmax.cuh"
37
38
#include " ggml-cuda/ssm-conv.cuh"
38
39
#include " ggml-cuda/ssm-scan.cuh"
@@ -2770,7 +2771,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2770
2771
}
2771
2772
#endif
2772
2773
2773
- static bool ggml_cuda_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
2774
+ static bool ggml_cuda_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
2775
+ #ifndef NDEBUG
2776
+ const size_t num_unary = std::count (ops.begin (), ops.end (), GGML_OP_UNARY);
2777
+ GGML_ASSERT (unary_ops.size () == num_unary);
2778
+ #endif
2779
+
2774
2780
if (!ggml_can_fuse (cgraph, node_idx, ops)) {
2775
2781
return false ;
2776
2782
}
@@ -2798,9 +2804,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2798
2804
if (!ggml_is_contiguous_rows (mul->src [0 ]) || !ggml_is_contiguous_rows (mul->src [1 ])) {
2799
2805
return false ;
2800
2806
}
2807
+
2808
+ return true ;
2801
2809
}
2802
2810
2803
- return true ;
2811
+ if (ops.size () == 3 && ops.begin ()[0 ] == GGML_OP_SCALE && ops.begin ()[1 ] == GGML_OP_UNARY && ops.begin ()[2 ] == GGML_OP_SCALE
2812
+ && unary_ops.size () == 1 && unary_ops.begin ()[0 ] == GGML_UNARY_OP_TANH) {
2813
+ const ggml_tensor *scale = cgraph->nodes [node_idx];
2814
+ const ggml_tensor *tanh = cgraph->nodes [node_idx+1 ];
2815
+ const ggml_tensor *scale2 = cgraph->nodes [node_idx+2 ];
2816
+
2817
+ GGML_ASSERT (scale->src [0 ]->type == GGML_TYPE_F32);
2818
+ GGML_ASSERT (scale->type == GGML_TYPE_F32);
2819
+
2820
+ if (ggml_get_unary_op (tanh) != GGML_UNARY_OP_TANH) {
2821
+ return false ;
2822
+ }
2823
+
2824
+ // Only fuse when neither scale op applies a bias (op param 1 must be 0 for both)
2825
+ if (ggml_get_op_params_f32 (scale, 1 ) != 0 .0f || ggml_get_op_params_f32 (scale2, 1 ) != 0 .0f ) {
2826
+ return false ;
2827
+ }
2828
+
2829
+ return true ;
2830
+ }
2831
+
2832
+ return false ;
2804
2833
}
2805
2834
2806
2835
static void evaluate_and_capture_cuda_graph (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
@@ -2821,10 +2850,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2821
2850
}
2822
2851
2823
2852
static bool disable_fusion = (getenv (" GGML_CUDA_DISABLE_FUSION" ) != nullptr );
2824
- if (!disable_fusion && ggml_cuda_can_fuse (cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2825
- ggml_cuda_op_rms_norm_fused (*cuda_ctx, node, cgraph->nodes [i+1 ]);
2826
- i++;
2827
- continue ;
2853
+ if (!disable_fusion) {
2854
+ if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
2855
+ ggml_cuda_op_rms_norm_fused (*cuda_ctx, node, cgraph->nodes [i+1 ]);
2856
+ i++;
2857
+ continue ;
2858
+ }
2859
+
2860
+ if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
2861
+ i += 2 ;
2862
+ ggml_cuda_op_softcap (*cuda_ctx, cgraph->nodes [i], node);
2863
+ continue ;
2864
+ }
2828
2865
}
2829
2866
#ifndef NDEBUG
2830
2867
assert (node->buffer ->buft == ggml_backend_cuda_buffer_type (cuda_ctx->device ));
0 commit comments