@@ -2008,6 +2008,97 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
20082008 }
20092009}
20102010
2011+ static bool ggml_cuda_should_fuse_mul_mat (const ggml_tensor * ffn_up,
2012+ const ggml_tensor * ffn_gate,
2013+ const ggml_tensor * glu) {
2014+ bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
2015+ bool is_mul_mat_id =
2016+ ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
2017+
2018+ GGML_ASSERT (ffn_up && ffn_gate && glu);
2019+
2020+ if (!is_mul_mat && !is_mul_mat_id) {
2021+ return false ;
2022+ }
2023+
2024+ if (ffn_up->src [0 ]->type != ffn_gate->src [0 ]->type || !ggml_are_same_shape (ffn_up->src [0 ], ffn_gate->src [0 ]) ||
2025+ !ggml_are_same_stride (ffn_up->src [0 ], ffn_gate->src [0 ])) {
2026+ return false ;
2027+ }
2028+
2029+ if (ffn_up->src [1 ] != ffn_gate->src [1 ]) {
2030+ return false ;
2031+ }
2032+
2033+ if (ffn_up->src [2 ] && (ffn_up->src [2 ] != ffn_gate->src [2 ])) {
2034+ return false ;
2035+ }
2036+
2037+ if (glu->src [0 ] != ffn_gate && glu->src [1 ] != ffn_up) {
2038+ return false ;
2039+ }
2040+
2041+ static constexpr std::array<ggml_glu_op, 2 > valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU };
2042+
2043+ if (std::find (valid_glu_ops.begin (), valid_glu_ops.end (), ggml_get_glu_op (glu)) == valid_glu_ops.end ()) {
2044+ return false ;
2045+ }
2046+
2047+ if (const bool swapped = ggml_get_op_params_i32 (glu, 1 ); swapped) {
2048+ return false ;
2049+ }
2050+
2051+ const bool split = ggml_backend_buft_is_cuda_split (ffn_up->src [0 ]->buffer ->buft ) ||
2052+ ggml_backend_buft_is_cuda_split (ffn_gate->src [0 ]->buffer ->buft );
2053+
2054+ // TODO: add support for fusion for split buffers
2055+ if (split) {
2056+ return false ;
2057+ }
2058+
2059+ return true ;
2060+ }
2061+
2062+ static bool ggml_cuda_should_fuse_mul_mat_vec_f (const ggml_tensor * tensor) {
2063+ ggml_tensor * src0 = tensor->src [0 ];
2064+ ggml_tensor * src1 = tensor->src [1 ];
2065+ const ggml_tensor * dst = tensor;
2066+
2067+ const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
2068+
2069+ bool use_mul_mat_vec_f =
2070+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
2071+ src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
2072+
2073+ const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
2074+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf (src0->type , cc, src0->ne , is_mul_mat_id ? src1->ne [2 ]: src1->ne [1 ]);
2075+
2076+ if (tensor->op == GGML_OP_MUL_MAT_ID) {
2077+ use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne [2 ] == 1 ;
2078+ }
2079+
2080+ return use_mul_mat_vec_f;
2081+ }
2082+
2083+ static bool ggml_cuda_should_fuse_mul_mat_vec_q (const ggml_tensor * tensor) {
2084+ ggml_tensor * src0 = tensor->src [0 ];
2085+ ggml_tensor * src1 = tensor->src [1 ];
2086+ const ggml_tensor * dst = tensor;
2087+
2088+ const bool bad_padding_clear = ggml_backend_buffer_get_usage (src0->buffer ) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
2089+ ggml_nbytes (src0) != ggml_backend_buffer_get_alloc_size (src0->buffer , src0) &&
2090+ src0->view_src ;
2091+
2092+ bool use_mul_mat_vec_q = ggml_is_quantized (src0->type ) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
2093+ dst->type == GGML_TYPE_F32 && src1->ne [1 ] <= MMVQ_MAX_BATCH_SIZE;
2094+
2095+ if (tensor->op == GGML_OP_MUL_MAT_ID) {
2096+ use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne [2 ] == 1 ;
2097+ }
2098+
2099+ return use_mul_mat_vec_q;
2100+ }
2101+
20112102static void ggml_cuda_mul_mat (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
20122103 const bool split = ggml_backend_buft_is_cuda_split (src0->buffer ->buft );
20132104
@@ -2758,7 +2849,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
27582849 }
27592850 }
27602851
2761- if (node->op == GGML_OP_SCALE &&
2852+ if (( node->op == GGML_OP_SCALE || node-> op == GGML_OP_GLU) &&
27622853 memcmp (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
27632854 return false ;
27642855 }
@@ -2872,6 +2963,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28722963 }
28732964 }
28742965
2966+ std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
2967+
2968+ std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2969+
2970+ if (ops.size () == 3 && (std::equal (ops.begin (), ops.end (), mul_mat_id_glu_ops.begin ()) ||
2971+ std::equal (ops.begin (), ops.end (), mul_mat_glu_ops.begin ()))) {
2972+ if (node_idx + 2 >= cgraph->n_nodes ) {
2973+ return false ;
2974+ }
2975+
2976+ const ggml_tensor * ffn_gate = cgraph->nodes [node_idx];
2977+ const ggml_tensor * ffn_up = cgraph->nodes [node_idx + 1 ];
2978+ const ggml_tensor * glu = cgraph->nodes [node_idx + 2 ];
2979+
2980+ if (ggml_cuda_should_fuse_mul_mat (ffn_up, ffn_gate, glu)) {
2981+ return true ;
2982+ }
2983+ }
2984+
28752985 if (!ggml_can_fuse (cgraph, node_idx, ops)) {
28762986 return false ;
28772987 }
@@ -3003,6 +3113,36 @@ static void evaluate_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_
30033113 }
30043114 }
30053115
3116+ bool fused_mul_mat_vec = false ;
3117+
3118+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3119+ if (ggml_cuda_can_fuse (cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3120+ ggml_tensor * glu = cgraph->nodes [i + 2 ];
3121+ ggml_tensor * gate = glu->src [0 ];
3122+ ggml_tensor * up = glu->src [1 ];
3123+
3124+ const ggml_tensor * src0 = up->src [0 ];
3125+ const ggml_tensor * src1 = up->src [1 ];
3126+ const ggml_tensor * ids = up->src [2 ];
3127+
3128+ if (ggml_cuda_should_fuse_mul_mat_vec_f (up)) {
3129+ ggml_cuda_mul_mat_vec_f (*cuda_ctx, src0, src1, ids, glu, gate->src [0 ], glu);
3130+ fused_mul_mat_vec = true ;
3131+ break ;
3132+ }
3133+
3134+ if (ggml_cuda_should_fuse_mul_mat_vec_q (up)) {
3135+ ggml_cuda_mul_mat_vec_q (*cuda_ctx, src0, src1, ids, glu, gate->src [0 ], glu);
3136+ fused_mul_mat_vec = true ;
3137+ break ;
3138+ }
3139+ }
3140+ }
3141+
3142+ if (fused_mul_mat_vec) {
3143+ i += 2 ;
3144+ continue ;
3145+ }
30063146
30073147 if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
30083148 ggml_cuda_op_rms_norm_fused_add (*cuda_ctx, node, cgraph->nodes [i+1 ], cgraph->nodes [i+2 ]);
0 commit comments