
Commit 8780a09

Merge branch 'ikawrakow:main' into main
2 parents 7a461f1 + 16f30fc commit 8780a09

File tree

4 files changed: +26, -26 lines


common/common.cpp

Lines changed: 6 additions & 6 deletions
@@ -994,8 +994,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = false;
         return true;
     }
-    if (arg == "-fa" || arg == "--flash-attn") {
-        params.flash_attn = true;
+    if (arg == "-no-fa" || arg == "--no-flash-attn") {
+        params.flash_attn = false;
         return true;
     }
     if (arg == "-mla" || arg == "--mla-use") {
@@ -1008,8 +1008,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.attn_max_batch = std::stoi(argv[i]);
         return true;
     }
-    if (arg == "-fmoe" || arg == "--fused-moe") {
-        params.fused_moe_up_gate = true;
+    if (arg == "-no-fmoe" || arg == "--no-fused-moe") {
+        params.fused_moe_up_gate = false;
         return true;
     }
     if (arg == "-ger" || arg == "--grouped-expert-routing") {
@@ -1804,10 +1804,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
     options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
     options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
-    options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+    options.push_back({ "*", "-no-fa, --no-flash-attn", "disable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
     options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
-    options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+    options.push_back({ "*", "-no-fmoe, --no-fused-moe", "disable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
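
Note: together with the default flips in common/common.h below, these flag renames make flash attention and the fused MoE op opt-out rather than opt-in. A minimal sketch of the resulting behaviour, assuming the existing gpt_params_parse() helper from common.h; the argv values are illustrative only and not part of this commit:

#include "common.h"

int main() {
    gpt_params params;  // defaults after this change: flash_attn = true, fused_moe_up_gate = true
    const char * argv[] = { "main", "-no-fa", "--no-fused-moe" };
    // gpt_params_parse() routes each switch through gpt_params_find_arg(),
    // so the renamed opt-out flags turn both features back off.
    gpt_params_parse(3, const_cast<char **>(argv), params);
    // params.flash_attn        == false
    // params.fused_moe_up_gate == false
    return 0;
}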

common/common.h

Lines changed: 2 additions & 2 deletions
@@ -230,10 +230,10 @@ struct gpt_params {
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
-    bool flash_attn        = false; // flash attention
+    bool flash_attn        = true;  // flash attention
     int  mla_attn          = 0;     // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
     int  attn_max_batch    = 0;     // Max batch size to use when computing attention (only applicable if flash_attn = false)
-    bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
+    bool fused_moe_up_gate = true;  // fused up*unary(gate) op for MoE models
     bool fused_up_gate     = true;  // fused up*unary(gate) op
     bool fused_mmad        = true;  // fused mul+multi_add op
     bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)

examples/llama-bench/llama-bench.cpp

Lines changed: 4 additions & 4 deletions
@@ -260,7 +260,7 @@ struct cmd_params {
     bool verbose;
     bool warmup;
     bool repack = false;
-    bool fmoe = false;
+    bool fmoe = true;
     bool ger = false; // ger = Grouped Expert Routing
     bool no_fug = false;
     bool use_thp = false;
@@ -285,7 +285,7 @@ static const cmd_params cmd_params_defaults = {
     /* split_mode    */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},
-    /* flash_attn    */ {false},
+    /* flash_attn    */ {true},
     /* mla_attn      */ {0},
     /* attn_max_batch*/ {0},
     /* ser           */ {{-1,0.0f}},
@@ -298,7 +298,7 @@ static const cmd_params cmd_params_defaults = {
     /* verbose       */ false,
     /* warmup        */ true,
     /* repack        */ false,
-    /* fmoe          */ false,
+    /* fmoe          */ true,
     /* ger           */ false,
     /* no_fug        */ false,
     /* use_thp       */ false,
@@ -846,7 +846,7 @@ struct cmd_params_instance {
     bool use_mmap;
     bool embeddings;
     bool repack = false;
-    bool fmoe = false;
+    bool fmoe = true;
     bool ger = false;
     bool no_fug = false;
     bool use_thp = false;

src/llama.cpp

Lines changed: 14 additions & 14 deletions
@@ -3750,10 +3750,10 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                =*/ false,
         /*.embeddings                =*/ false,
         /*.offload_kqv               =*/ true,
-        /*.flash_attn                =*/ false,
+        /*.flash_attn                =*/ true,
         /*.mla_attn                  =*/ 0,
         /*.attn_max_batch            =*/ 0,
-        /*.fused_moe_up_gate         =*/ false,
+        /*.fused_moe_up_gate         =*/ true,
         /*.grouped_expert_routing    =*/ false,
         /*.fused_up_gate             =*/ true,
         /*.fused_mmad                =*/ true,
@@ -4040,19 +4040,19 @@ struct llama_context * llama_new_context_with_model(
         cparams.mla_attn = 0;
     }

-    LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: mla_attn   = %d\n",     __func__, cparams.mla_attn);
-    LLAMA_LOG_INFO("%s: attn_max_b = %d\n",     __func__, cparams.attn_max_batch);
-    LLAMA_LOG_INFO("%s: fused_moe  = %d\n",     __func__, cparams.fused_moe_up_gate);
-    LLAMA_LOG_INFO("%s: grouped er = %d\n",     __func__, cparams.grouped_expert_routing);
+    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",     __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_batch       = %u\n",     __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",     __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",     __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: mla_attn      = %d\n",     __func__, cparams.mla_attn);
+    LLAMA_LOG_INFO("%s: attn_max_b    = %d\n",     __func__, cparams.attn_max_batch);
+    LLAMA_LOG_INFO("%s: fused_moe     = %d\n",     __func__, cparams.fused_moe_up_gate);
+    LLAMA_LOG_INFO("%s: grouped er    = %d\n",     __func__, cparams.grouped_expert_routing);
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n",  __func__, cparams.fused_up_gate);
-    LLAMA_LOG_INFO("%s: fused_mmad = %d\n",     __func__, cparams.fused_mmad);
-    LLAMA_LOG_INFO("%s: ser        = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
-    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
+    LLAMA_LOG_INFO("%s: fused_mmad    = %d\n",     __func__, cparams.fused_mmad);
+    LLAMA_LOG_INFO("%s: ser           = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
+    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n",   __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",     __func__, cparams.rope_freq_scale);

     ctx->abort_callback      = params.abort_callback;
     ctx->abort_callback_data = params.abort_callback_data;
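
Note: llama_context_default_params() now returns flash_attn = true and fused_moe_up_gate = true, so API callers that relied on the old defaults must opt out explicitly. A minimal sketch, not part of this commit:

#include "llama.h"

// Restore the pre-change behaviour for a context created through the C API.
static struct llama_context_params legacy_ctx_params(void) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn        = false;  // default was false before this change
    cparams.fused_moe_up_gate = false;  // likewise
    return cparams;
}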
