Skip to content

Commit 9469456

Browse files
committed
Merge remote-tracking branch 'amllama.cpp/cuda_fuse_gate' into esocrok
1 parent 3203a80 commit 9469456

File tree

8 files changed

+510
-190
lines changed

8 files changed

+510
-190
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2008,6 +2008,97 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
20082008
}
20092009
}
20102010

2011+
// Decide whether an (ffn_up, ffn_gate, glu) triple of graph nodes is eligible
// for fusion into a single gated mat-mul launch.
// Requirements checked:
//   - both projections are MUL_MAT (or both MUL_MAT_ID) feeding a GLU node;
//   - the two weight tensors share type, shape and strides;
//   - both projections consume the same activations (and the same ids tensor
//     for the MUL_MAT_ID variant);
//   - the GLU op is SwiGLU or GeGLU, non-swapped;
//   - neither weight lives in a CUDA split buffer (not yet supported).
// Returns true only when all of the above hold.
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
                                          const ggml_tensor * ffn_gate,
                                          const ggml_tensor * glu) {
    // Validate the pointers BEFORE dereferencing them. (Previously this assert
    // ran after ffn_up->op / ffn_gate->op / glu->op were already read, so a
    // null argument would crash before the check could fire.)
    GGML_ASSERT(ffn_up && ffn_gate && glu);

    const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
    const bool is_mul_mat_id =
        ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;

    if (!is_mul_mat && !is_mul_mat_id) {
        return false;
    }

    // The up and gate weight matrices must be addressable identically so one
    // fused kernel can read both: same type, same shape, same strides.
    if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
        !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
        return false;
    }

    // Both projections must consume the exact same activation tensor.
    if (ffn_up->src[1] != ffn_gate->src[1]) {
        return false;
    }

    // For MUL_MAT_ID, the expert-selection ids must match as well.
    if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
        return false;
    }

    // NOTE(review): with &&, this rejects only when BOTH inputs mismatch; if the
    // intent is "src[0] must be the gate AND src[1] must be the up projection"
    // this should be ||. Kept as-is to preserve behavior — confirm upstream.
    if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
        return false;
    }

    // Only SwiGLU and GeGLU activations have fused kernels.
    static constexpr std::array<ggml_glu_op, 2> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU };

    if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
        return false;
    }

    // The swapped gate/up layout (op_params[1]) is not handled by the fused path.
    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
        return false;
    }

    const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
                       ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);

    //TODO: add support for fusion for split buffers
    if (split) {
        return false;
    }

    return true;
}
2061+
2062+
// Returns true when `tensor` (a MUL_MAT or MUL_MAT_ID node) qualifies for the
// fused float mul-mat-vec path: F32/F16/BF16 weights with F32 activations and
// F32 output, the mmvf heuristic approves for the current device, and — for the
// MUL_MAT_ID variant — the output has a single slice in dim 2.
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
    ggml_tensor *       weights = tensor->src[0];
    ggml_tensor *       acts    = tensor->src[1];
    const ggml_tensor * out     = tensor;

    const bool is_id_variant = tensor->op == GGML_OP_MUL_MAT_ID;

    const bool weights_type_ok = weights->type == GGML_TYPE_F32 ||
                                 weights->type == GGML_TYPE_F16 ||
                                 weights->type == GGML_TYPE_BF16;

    bool eligible = weights_type_ok && acts->type == GGML_TYPE_F32 && out->type == GGML_TYPE_F32;

    // Query the compute capability of the active device for the mmvf heuristic.
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    // MUL_MAT_ID keeps its batch count in ne[2] of the activations; plain
    // MUL_MAT keeps it in ne[1].
    const int64_t n_cols = is_id_variant ? acts->ne[2] : acts->ne[1];
    eligible = eligible && ggml_cuda_should_use_mmvf(weights->type, cc, weights->ne, n_cols);

    if (is_id_variant) {
        eligible = eligible && out->ne[2] == 1;
    }

    return eligible;
}
2082+
2083+
// Returns true when `tensor` (a MUL_MAT or MUL_MAT_ID node) qualifies for the
// fused quantized mul-mat-vec path: quantized weights without a stale-padding
// hazard, F32 activations and output, batch within MMVQ_MAX_BATCH_SIZE, and —
// for the MUL_MAT_ID variant — a single output slice in dim 2.
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
    ggml_tensor *       weights = tensor->src[0];
    ggml_tensor *       acts    = tensor->src[1];
    const ggml_tensor * out     = tensor;

    // A weight view in a compute buffer whose allocation exceeds its data may
    // carry padding that was never cleared — unsafe for the quantized kernels.
    const bool padding_may_be_stale =
        ggml_backend_buffer_get_usage(weights->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
        ggml_nbytes(weights) != ggml_backend_buffer_get_alloc_size(weights->buffer, weights) &&
        weights->view_src;

    bool eligible = ggml_is_quantized(weights->type) &&
                    !padding_may_be_stale &&
                    acts->type == GGML_TYPE_F32 &&
                    out->type == GGML_TYPE_F32 &&
                    acts->ne[1] <= MMVQ_MAX_BATCH_SIZE;

    if (tensor->op == GGML_OP_MUL_MAT_ID) {
        eligible = eligible && out->ne[2] == 1;
    }

    return eligible;
}
2101+
20112102
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
20122103
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
20132104

@@ -2758,7 +2849,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
27582849
}
27592850
}
27602851

2761-
if (node->op == GGML_OP_SCALE &&
2852+
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
27622853
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
27632854
return false;
27642855
}
@@ -2857,6 +2948,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28572948
}
28582949
}
28592950

2951+
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
2952+
2953+
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2954+
2955+
if (ops.size() == 3 && (std::equal(ops.begin(), ops.end(), mul_mat_id_glu_ops.begin()) ||
2956+
std::equal(ops.begin(), ops.end(), mul_mat_glu_ops.begin()))) {
2957+
if (node_idx + 2 >= cgraph->n_nodes) {
2958+
return false;
2959+
}
2960+
2961+
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
2962+
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
2963+
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
2964+
2965+
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
2966+
return true;
2967+
}
2968+
}
2969+
28602970
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
28612971
return false;
28622972
}
@@ -2988,6 +3098,36 @@ static void evaluate_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_
29883098
}
29893099
}
29903100

3101+
bool fused_mul_mat_vec = false;
3102+
3103+
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3104+
if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3105+
ggml_tensor * glu = cgraph->nodes[i + 2];
3106+
ggml_tensor * gate = glu->src[0];
3107+
ggml_tensor * up = glu->src[1];
3108+
3109+
const ggml_tensor * src0 = up->src[0];
3110+
const ggml_tensor * src1 = up->src[1];
3111+
const ggml_tensor * ids = up->src[2];
3112+
3113+
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3114+
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
3115+
fused_mul_mat_vec = true;
3116+
break;
3117+
}
3118+
3119+
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3120+
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
3121+
fused_mul_mat_vec = true;
3122+
break;
3123+
}
3124+
}
3125+
}
3126+
3127+
if (fused_mul_mat_vec) {
3128+
i += 2;
3129+
continue;
3130+
}
29913131

29923132
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
29933133
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);

0 commit comments

Comments
 (0)