Commit d0cf314

CANN: add fused ffn op

1 parent: 5fd160b
12 files changed, +84 -2 lines
common/arg.cpp

Lines changed: 7 additions & 0 deletions

@@ -1498,6 +1498,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.flash_attn = true;
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg(
+        {"-ffn", "--feed-forward-network"},
+        string_format("enable fused feed forward network (default: %s)", params.ffn ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ffn = true;
+        }
+    ).set_env("LLAMA_ARG_FFN"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",

common/common.cpp

Lines changed: 1 addition & 0 deletions

@@ -1170,6 +1170,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
     cparams.flash_attn        = params.flash_attn;
+    cparams.ffn               = params.ffn;
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
     cparams.swa_full          = params.swa_full;

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -347,6 +347,7 @@ struct common_params {
     bool simple_io     = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn    = false; // flash attention
+    bool ffn           = false; // fused feed forward network
     bool no_perf       = false; // disable performance metrics
     bool ctx_shift     = true;  // context shift on infinite text generation
     bool swa_full      = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)

ggml/include/ggml.h

Lines changed: 17 additions & 0 deletions

@@ -543,6 +543,8 @@ extern "C" {
 
         GGML_OP_GLU,
 
+        GGML_OP_FFN,
+
         GGML_OP_COUNT,
     };
 

@@ -2097,6 +2099,21 @@ extern "C" {
             struct ggml_tensor  * d,
             bool                  masked);
 
+    GGML_API struct ggml_tensor * ggml_ffn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * cur,
+            struct ggml_tensor  * up,
+            struct ggml_tensor  * up_b,
+            struct ggml_tensor  * up_s,
+            struct ggml_tensor  * gate,
+            struct ggml_tensor  * gate_b,
+            struct ggml_tensor  * gate_s,
+            struct ggml_tensor  * down,
+            struct ggml_tensor  * down_b,
+            struct ggml_tensor  * down_s,
+            struct ggml_tensor  * act_scales,
+            int                   type_gate);
+
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
             struct ggml_tensor  * sx,
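The prototype above is the entire public surface added for the new op. As a purely illustrative sketch (not part of the commit), a caller building a graph might wire it up as below; it assumes that optional operands (biases, scales, `act_scales`) may be passed as `NULL` and that the final argument selects the gating/activation variant, neither of which is documented by this commit.

    // Hypothetical helper showing how ggml_ffn_ext() might be called.
    // The NULL-for-optional-operand convention and the meaning of type_gate
    // are assumptions, not guarantees made by this commit.
    #include "ggml.h"

    static struct ggml_tensor * build_fused_ffn(
            struct ggml_context * ctx,
            struct ggml_tensor  * cur,      // input activations
            struct ggml_tensor  * w_up,     // up-projection weight
            struct ggml_tensor  * w_gate,   // gate-projection weight
            struct ggml_tensor  * w_down,   // down-projection weight
            int                   type_gate) {
        return ggml_ffn_ext(ctx, cur,
            w_up,   /*up_b  =*/ NULL, /*up_s  =*/ NULL,
            w_gate, /*gate_b=*/ NULL, /*gate_s=*/ NULL,
            w_down, /*down_b=*/ NULL, /*down_s=*/ NULL,
            /*act_scales=*/ NULL,
            type_gate);
    }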

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 4 additions & 0 deletions

@@ -3397,3 +3397,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         GGML_ABORT("Function is not implemented.");
     }
 }
+
+void ggml_cann_ffn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+
+}
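`ggml_cann_ffn` is introduced here as an empty stub; the kernel body is left for a later change. As a non-authoritative sketch of what a fused FFN op of this shape usually computes (inferred from the standard gated-FFN pattern in llama.cpp, not stated by this commit):

    // Presumed unfused equivalent that the CANN kernel would collapse into one call:
    //   t_up   = up   · cur      (+ up_b,   scaled by up_s,   if present)
    //   t_gate = gate · cur      (+ gate_b, scaled by gate_s, if present)
    //   t_act  = activation(t_gate)              // variant selected by the op's gate/type parameter
    //   out    = down · (t_act * t_up)           // element-wise product, (+ down_b / down_s if present)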

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 2 additions & 0 deletions

@@ -740,6 +740,8 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+void ggml_cann_ffn(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 4 additions & 0 deletions

@@ -1881,6 +1881,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cann_flash_attn_ext(ctx, dst);
             break;
+        case GGML_OP_FFN:
+            ggml_cann_ffn(ctx, dst);
         default:
             return false;
     }

@@ -2544,6 +2546,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         }
+        case GGML_OP_FFN:
+            return true;
         default:
             return false;
     }
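Note that the new `case GGML_OP_FFN:` in `ggml_cann_compute_forward` does not end with `break;`, so after `ggml_cann_ffn(ctx, dst)` returns, control falls through into `default: return false;` and the op is reported as unhandled. Assuming fallthrough is unintended (every neighbouring case breaks), the likely intended form is:

        case GGML_OP_FFN:
            ggml_cann_ffn(ctx, dst);
            break; // assumed fix: prevents falling through to `default: return false;`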

ggml/src/ggml.c

Lines changed: 38 additions & 2 deletions

@@ -1014,9 +1014,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 
     "GLU",
+    "FFN",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -1115,9 +1116,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 
     "glu(x)",
+    "ffn(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 

@@ -4958,6 +4960,40 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }
 
+struct ggml_tensor * ggml_ffn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * cur,
+        struct ggml_tensor  * up,
+        struct ggml_tensor  * up_b,
+        struct ggml_tensor  * up_s,
+        struct ggml_tensor  * gate,
+        struct ggml_tensor  * gate_b,
+        struct ggml_tensor  * gate_s,
+        struct ggml_tensor  * down,
+        struct ggml_tensor  * down_b,
+        struct ggml_tensor  * down_s,
+        struct ggml_tensor  * act_scales,
+        int                   type_op) {
+    int64_t ne[] = {10, 10, 10, 10};
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, type_op);
+
+    result->op     = GGML_OP_FFN;
+    result->src[0] = up;
+    result->src[1] = up_b;
+    result->src[2] = up_s;
+    result->src[3] = gate;
+    result->src[4] = gate_b;
+    result->src[5] = gate_s;
+    result->src[6] = down;
+    result->src[7] = down_b;
+    result->src[8] = down_s;
+    result->src[9] = act_scales;
+
+    return result;
+}
+
 // ggml_ssm_conv
 
 struct ggml_tensor * ggml_ssm_conv(
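Two aspects of this constructor read as placeholders rather than final behaviour: the result shape is hard-coded to `{10, 10, 10, 10}` instead of being derived from the operands (presumably it would eventually follow `down`'s output dimension and `cur`'s token count), and `cur` itself is never recorded on the node; all ten `src` slots are consumed by the weights, biases, scales and `act_scales`, leaving no graph edge to the activations being transformed. Also note that the declaration in `ggml/include/ggml.h` names the last parameter `type_gate` while this definition calls it `type_op`. The value is stored in `op_params[0]`, which a backend would presumably read back with the standard helper:

    // Reading the activation/gate selector back inside a backend (sketch):
    const int32_t type_op = ggml_get_op_params_i32(dst, 0); // dst is the GGML_OP_FFN node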

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -332,6 +332,7 @@ extern "C" {
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool ffn;         // use fused ffn
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions

@@ -43,6 +43,7 @@ llama_context::llama_context(
     cparams.embeddings   = params.embeddings;
     cparams.offload_kqv  = params.offload_kqv;
     cparams.flash_attn   = params.flash_attn;
+    cparams.ffn          = params.ffn;
     cparams.no_perf      = params.no_perf;
     cparams.pooling_type = params.pooling_type;
     cparams.warmup       = false;

@@ -2265,6 +2266,7 @@ llama_context_params llama_context_default_params() {
         /*.embeddings  =*/ false,
         /*.offload_kqv =*/ true,
         /*.flash_attn  =*/ false,
+        /*.ffn         =*/ false,
         /*.no_perf     =*/ true,
         /*.op_offload  =*/ true,
         /*.swa_full    =*/ true,
