
Commit 42dee65

Merge remote-tracking branch 'amllama.cpp/cuda_fuse_gate' into esocrok
1 parent 47a22a6

File tree

8 files changed: +510 −190 lines


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 141 additions & 1 deletion
@@ -2008,6 +2008,97 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 }
 
+static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
+                                          const ggml_tensor * ffn_gate,
+                                          const ggml_tensor * glu) {
+    GGML_ASSERT(ffn_up && ffn_gate && glu);
+
+    bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
+    bool is_mul_mat_id =
+        ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
+
+    if (!is_mul_mat && !is_mul_mat_id) {
+        return false;
+    }
+
+    if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
+        !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
+        return false;
+    }
+
+    if (ffn_up->src[1] != ffn_gate->src[1]) {
+        return false;
+    }
+
+    if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
+        return false;
+    }
+
+    if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
+        return false;
+    }
+
+    static constexpr std::array<ggml_glu_op, 2> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU };
+
+    if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
+        return false;
+    }
+
+    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
+        return false;
+    }
+
+    const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
+                       ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
+
+    // TODO: add support for fusion for split buffers
+    if (split) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
+    ggml_tensor * src0 = tensor->src[0];
+    ggml_tensor * src1 = tensor->src[1];
+    const ggml_tensor * dst = tensor;
+
+    const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
+
+    bool use_mul_mat_vec_f =
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
+        src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
+
+    if (tensor->op == GGML_OP_MUL_MAT_ID) {
+        use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne[2] == 1;
+    }
+
+    return use_mul_mat_vec_f;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
+    ggml_tensor * src0 = tensor->src[0];
+    ggml_tensor * src1 = tensor->src[1];
+    const ggml_tensor * dst = tensor;
+
+    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
+                                   ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
+                                   src0->view_src;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
+                             dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    if (tensor->op == GGML_OP_MUL_MAT_ID) {
+        use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne[2] == 1;
+    }
+
+    return use_mul_mat_vec_q;
+}
+
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
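
In plain terms, ggml_cuda_should_fuse_mul_mat only admits the gate/up/GLU triple when both matrix multiplications read the same activations (src[1]), their weight tensors have matching type, shape, and stride, the GLU actually consumes those two nodes, and the GLU variant is an unswapped SWIGLU or GEGLU. Below is a minimal CPU sketch, not ggml code, of what the fused path computes per output element in the SwiGLU case; dot and fused_swiglu_row are illustrative helper names:

    #include <cmath>
    #include <cstddef>

    // dot product helper
    static float dot(const float * a, const float * b, std::size_t n) {
        float s = 0.0f;
        for (std::size_t i = 0; i < n; ++i) {
            s += a[i] * b[i];
        }
        return s;
    }

    // Reference for one fused SwiGLU FFN output element:
    //   out[r] = silu(W_gate[r] . x) * (W_up[r] . x)
    // A fused kernel can read x once and never materialize the two
    // intermediate projection tensors.
    static float fused_swiglu_row(const float * w_gate_row, const float * w_up_row,
                                  const float * x, std::size_t n) {
        const float g = dot(w_gate_row, x, n);     // gate mat-vec (GGML_OP_MUL_MAT #1)
        const float u = dot(w_up_row,   x, n);     // up mat-vec   (GGML_OP_MUL_MAT #2)
        const float s = g / (1.0f + std::exp(-g)); // SiLU, the GGML_GLU_OP_SWIGLU gate
        return s * u;                              // GLU combine  (GGML_OP_GLU)
    }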

@@ -2758,7 +2849,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
         }
     }
 
-    if (node->op == GGML_OP_SCALE &&
+    if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
         memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
         return false;
     }
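
For context, this check guards CUDA-graph reuse: a GLU node's op_params hold the GLU variant and its swapped flag, so those bytes must be identical before a previously captured graph can be replayed, just as was already required for GGML_OP_SCALE. A standalone sketch of the comparison, with stand-in names (node_props, MAX_OP_PARAMS) rather than ggml's:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Stand-ins: MAX_OP_PARAMS mirrors GGML_MAX_OP_PARAMS (a fixed byte count),
    // node_props mirrors the cached per-node properties kept for graph reuse.
    constexpr std::size_t MAX_OP_PARAMS = 64;

    struct node_props {
        int32_t op;
        char    op_params[MAX_OP_PARAMS];
    };

    // For ops whose behavior depends on op_params (SCALE, and with this change
    // GLU), any byte difference forces a CUDA-graph re-capture.
    static bool op_params_match(const node_props & cached, int32_t op, const char * op_params) {
        if (cached.op != op) {
            return false;
        }
        return std::memcmp(cached.op_params, op_params, MAX_OP_PARAMS) == 0;
    }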
@@ -2872,6 +2963,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
+
+    std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
+
+    if (ops.size() == 3 && (std::equal(ops.begin(), ops.end(), mul_mat_id_glu_ops.begin()) ||
+                            std::equal(ops.begin(), ops.end(), mul_mat_glu_ops.begin()))) {
+        if (node_idx + 2 >= cgraph->n_nodes) {
+            return false;
+        }
+
+        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
+        const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];
+
+        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
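
Note the node order: the gate projection is cgraph->nodes[node_idx] and the up projection follows it, while the predicate takes (ffn_up, ffn_gate, glu). A self-contained sketch of the same three-node window match, using illustrative types in place of ggml's graph structures:

    #include <array>
    #include <cstddef>
    #include <vector>

    enum class op { mul_mat, mul_mat_id, glu, other };

    // Same bounds check and pattern match as the hunk above, over a plain
    // op sequence instead of ggml graph nodes.
    static bool matches_glu_triple(const std::vector<op> & nodes, std::size_t idx) {
        const std::array<op, 3> mm_pat   = { op::mul_mat,    op::mul_mat,    op::glu };
        const std::array<op, 3> mmid_pat = { op::mul_mat_id, op::mul_mat_id, op::glu };

        if (idx + 2 >= nodes.size()) {   // the fused window must fit in the graph
            return false;
        }
        const std::array<op, 3> window = { nodes[idx], nodes[idx + 1], nodes[idx + 2] };
        return window == mm_pat || window == mmid_pat;
    }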
@@ -3003,6 +3113,36 @@ static void evaluate_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_
                 }
             }
 
+            bool fused_mul_mat_vec = false;
+
+            for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
+                if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
+                    ggml_tensor * glu  = cgraph->nodes[i + 2];
+                    ggml_tensor * gate = glu->src[0];
+                    ggml_tensor * up   = glu->src[1];
+
+                    const ggml_tensor * src0 = up->src[0];
+                    const ggml_tensor * src1 = up->src[1];
+                    const ggml_tensor * ids  = up->src[2];
+
+                    if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
+                        ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
+                        fused_mul_mat_vec = true;
+                        break;
+                    }
+
+                    if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
+                        ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
+                        fused_mul_mat_vec = true;
+                        break;
+                    }
+                }
+            }
+
+            if (fused_mul_mat_vec) {
+                i += 2;
+                continue;
+            }
 
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
                 ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
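
When either fused matrix-vector path fires, a single launch writes glu's output directly, so the evaluation loop advances past the two producer nodes rather than launching them again. A toy sketch of this skip-ahead dispatch pattern, with all names illustrative:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct node { const char * name; };

    // Placeholder for the real checks (ggml_cuda_can_fuse plus the mmvf/mmvq
    // predicates); here every window of three nodes "fuses".
    static bool try_fuse_triple(const std::vector<node> & nodes, std::size_t i) {
        return i + 2 < nodes.size();
    }

    static void evaluate(const std::vector<node> & nodes) {
        for (std::size_t i = 0; i < nodes.size(); ++i) {
            if (try_fuse_triple(nodes, i)) {
                std::printf("fused:  %s + %s + %s\n",
                            nodes[i].name, nodes[i + 1].name, nodes[i + 2].name);
                i += 2;   // the fused launch already produced these nodes' outputs
                continue;
            }
            std::printf("single: %s\n", nodes[i].name);
        }
    }

    int main() {
        evaluate({ { "ffn_gate" }, { "ffn_up" }, { "glu" }, { "ffn_out" } });
        return 0;
    }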
