Skip to content

Commit 1789de5

Browse files
ikawrakow and Iwan Kawrakow authored
Make ooae on by default and add to llama-bench (ikawrakow#842)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 2f5dae2 commit 1789de5

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

common/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1390,7 +1390,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
13901390
}
13911391
return true;
13921392
}
1393-
if (arg == "--offload-only-active-experts" || arg == "-ooae") {
1393+
if (arg == "--no-offload-only-active-experts" || arg == "-no-ooae") {
13941394
// The "--no-..." spelling must switch the feature OFF. only_active_exps now
// defaults to true, so assigning true here made the flag a silent no-op.
params.only_active_exps = false;
13951395
return true;
13961396
}

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ struct gpt_params {
255255
bool repack_tensors = false; // repack tensors if interleaved variant is available
256256
bool use_thp = false; // use transparent huge pages (linux only)
257257
bool validate_quants = false; // if true, check for NaNs while loading the model
258-
bool only_active_exps = false; // if true, offload only active experts (relevant only for hybrid CPU/GPU)
258+
bool only_active_exps = true; // if true, offload only active experts (relevant only for hybrid CPU/GPU)
259259

260260
std::string cache_type_k = "f16"; // KV cache data type for the K
261261
std::string cache_type_v = "f16"; // KV cache data type for the V

examples/llama-bench/llama-bench.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ struct cmd_params {
264264
bool ger = false; // ger = Grouped Expert Routing
265265
bool no_fug = false;
266266
bool use_thp = false;
267+
bool no_ooae = false;
267268
output_formats output_format;
268269
output_formats output_format_stderr;
269270
};
@@ -301,6 +302,7 @@ static const cmd_params cmd_params_defaults = {
301302
/* ger */ false,
302303
/* no_fug */ false,
303304
/* use_thp */ false,
305+
/* no_ooae */ false,
304306
/* output_format */ MARKDOWN,
305307
/* output_format_stderr */ NONE,
306308
};
@@ -345,6 +347,7 @@ static void print_usage(int /* argc */, char ** argv) {
345347
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
346348
printf(" -ger, --grouped-expert-routing <0|1>(default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
347349
printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
350+
printf(" -no-ooae, --no-offload-only-active-experts <0|1> (default: %s)\n", cmd_params_defaults.no_ooae? "1" : "0");
348351
printf("\n");
349352
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
350353
}
@@ -754,6 +757,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
754757
break;
755758
}
756759
params.no_fug = std::stoi(argv[i]);
760+
} else if (arg == "-no-ooae" || arg == "--no-offload-only-active-experts") {
761+
if (++i >= argc) {
762+
invalid_param = true;
763+
break;
764+
}
765+
params.no_ooae = std::stoi(argv[i]);
757766
} else if (arg == "-ot" || arg == "--override-tensor") {
758767
if (++i >= argc) {
759768
invalid_param = true;
@@ -841,6 +850,7 @@ struct cmd_params_instance {
841850
bool ger = false;
842851
bool no_fug = false;
843852
bool use_thp = false;
853+
bool no_ooae = false;
844854
const llama_model_tensor_buft_override* buft_overrides;
845855

846856
llama_model_params to_llama_mparams() const {
@@ -888,6 +898,7 @@ struct cmd_params_instance {
888898
cparams.fused_moe_up_gate = fmoe;
889899
cparams.grouped_expert_routing = ger;
890900
cparams.fused_up_gate = !no_fug;
901+
cparams.only_active_experts = !no_ooae;
891902
cparams.min_experts = ser.first;
892903
cparams.thresh_experts = ser.second;
893904
cparams.embeddings = embeddings;
@@ -949,6 +960,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
949960
/* .ger = */ params.ger,
950961
/* .no_fug = */ params.no_fug,
951962
/* .use_thp = */ params.use_thp,
963+
/* .no_ooae = */ params.no_ooae,
952964
/* .buft_overrides=*/ params.buft_overrides.data(),
953965
};
954966
instances.push_back(instance);
@@ -985,6 +997,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
985997
/* .ger = */ params.ger,
986998
/* .no_fug = */ params.no_fug,
987999
/* .use_thp = */ params.use_thp,
1000+
/* .no_ooae = */ params.no_ooae,
9881001
/* .buft_overrides=*/ params.buft_overrides.data(),
9891002
};
9901003
instances.push_back(instance);
@@ -1021,6 +1034,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10211034
/* .ger = */ params.ger,
10221035
/* .no_fug = */ params.no_fug,
10231036
/* .use_thp = */ params.use_thp,
1037+
/* .no_ooae = */ params.no_ooae,
10241038
/* .buft_overrides=*/ params.buft_overrides.data(),
10251039
};
10261040
instances.push_back(instance);
@@ -1057,6 +1071,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10571071
/* .ger = */ params.ger,
10581072
/* .no_fug = */ params.no_fug,
10591073
/* .use_thp = */ params.use_thp,
1074+
/* .no_ooae = */ params.no_ooae,
10601075
/* .buft_overrides=*/ params.buft_overrides.data(),
10611076
};
10621077
instances.push_back(instance);
@@ -1104,6 +1119,7 @@ struct test {
11041119
bool ger = false;
11051120
bool no_fug = false;
11061121
bool use_thp = false;
1122+
bool no_ooae = false;
11071123
int n_prompt;
11081124
int n_gen;
11091125
std::string test_time;
@@ -1140,6 +1156,7 @@ struct test {
11401156
ger = inst.ger;
11411157
no_fug = inst.no_fug;
11421158
use_thp = inst.use_thp;
1159+
no_ooae = inst.no_ooae;
11431160
n_prompt = inst.n_prompt;
11441161
n_gen = inst.n_gen;
11451162
test_kind = inst.test_kind;
@@ -1230,7 +1247,7 @@ struct test {
12301247
"n_threads", "type_k", "type_v",
12311248
"n_gpu_layers", "split_mode",
12321249
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
1233-
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp",
1250+
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae",
12341251
"n_prompt", "n_gen", "test_time",
12351252
"avg_ns", "stddev_ns",
12361253
"avg_ts", "stddev_ts", "test",
@@ -1252,7 +1269,7 @@ struct test {
12521269
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
12531270
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
12541271
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
1255-
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") {
1272+
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae") {
12561273
return BOOL;
12571274
}
12581275
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1296,7 +1313,7 @@ struct test {
12961313
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
12971314
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
12981315
std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
1299-
std::to_string(no_fug), std::to_string(use_thp),
1316+
std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae),
13001317
std::to_string(n_prompt), std::to_string(n_gen), test_time,
13011318
std::to_string(avg_ns()), std::to_string(stdev_ns()),
13021319
std::to_string(avg_ts()), std::to_string(stdev_ts()),
@@ -1486,6 +1503,9 @@ struct markdown_printer : public printer {
14861503
if (field == "fused_up_gate") {
14871504
return 6;
14881505
}
1506+
if (field == "ooae") {
1507+
return 7;
1508+
}
14891509
if (field == "test") {
14901510
return 13;
14911511
}
@@ -1544,6 +1564,9 @@ struct markdown_printer : public printer {
15441564
if (field == "fused_up_gate") {
15451565
return "no-fug";
15461566
}
1567+
if (field == "ooae") {
1568+
return "no-ooae";
1569+
}
15471570
if (field == "embeddings") {
15481571
return "embd";
15491572
}
@@ -1623,6 +1646,9 @@ struct markdown_printer : public printer {
16231646
if (params.no_fug != cmd_params_defaults.no_fug) {
16241647
fields.emplace_back("fused_up_gate");
16251648
}
1649+
if (params.no_ooae != cmd_params_defaults.no_ooae) {
1650+
fields.emplace_back("ooae");
1651+
}
16261652
fields.emplace_back("test");
16271653
fields.emplace_back("t/s");
16281654

0 commit comments

Comments
 (0)