Skip to content

Commit 1e6f8ff

Browse files
ikawrakowIwan Kawrakow
andauthored
Add rcache to llama-bench (#936)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 489554b commit 1e6f8ff

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ struct cmd_params {
267267
bool use_thp = false;
268268
bool no_ooae = false;
269269
bool mqkv = false;
270+
bool rcache = false;
270271
output_formats output_format;
271272
output_formats output_format_stderr;
272273
};
@@ -307,6 +308,7 @@ static const cmd_params cmd_params_defaults = {
307308
/* use_thp */ false,
308309
/* no_ooae */ false,
309310
/* mqkv */ false,
311+
/* rcache */ false,
310312
/* output_format */ MARKDOWN,
311313
/* output_format_stderr */ NONE,
312314
};
@@ -348,6 +350,7 @@ static void print_usage(int /* argc */, char ** argv) {
348350
printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
349351
printf(" -cuda, --cuda-params <string> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
350352
printf(" -mqkv, --merge-qkv (default: %s)\n", cmd_params_defaults.mqkv ? "1" : "0");
353+
printf(" -rcache, --rope-cache (default: %s)\n", cmd_params_defaults.rcache ? "1" : "0");
351354
printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
352355
printf(" -ot, --override-tensor pattern (default: none)\n");
353356
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
@@ -751,6 +754,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
751754
break;
752755
}
753756
params.mqkv = std::stoi(argv[i]);
757+
} else if (arg == "-rcache" || arg == "--rope-cache") {
758+
if (++i >= argc) {
759+
invalid_param = true;
760+
break;
761+
}
762+
params.rcache = std::stoi(argv[i]);
754763
} else if (arg == "-thp" || arg == "--transparent-huge-pages") {
755764
if (++i >= argc) {
756765
invalid_param = true;
@@ -871,6 +880,7 @@ struct cmd_params_instance {
871880
bool use_thp = false;
872881
bool no_ooae = false;
873882
bool mqkv = false;
883+
bool rcache = false;
874884
const llama_model_tensor_buft_override* buft_overrides;
875885

876886
llama_model_params to_llama_mparams() const {
@@ -919,6 +929,7 @@ struct cmd_params_instance {
919929
cparams.attn_max_batch = attn_max_batch;
920930
cparams.fused_moe_up_gate = fmoe;
921931
cparams.grouped_expert_routing = ger;
932+
cparams.rope_cache = rcache;
922933
cparams.fused_up_gate = !no_fug;
923934
cparams.only_active_experts = !no_ooae;
924935
cparams.min_experts = ser.first;
@@ -986,6 +997,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
986997
/* .use_thp = */ params.use_thp,
987998
/* .no_ooae = */ params.no_ooae,
988999
/* .mqkv = */ params.mqkv,
1000+
/* .rcache = */ params.rcache,
9891001
/* .buft_overrides=*/ params.buft_overrides.data(),
9901002
};
9911003
instances.push_back(instance);
@@ -1025,6 +1037,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10251037
/* .use_thp = */ params.use_thp,
10261038
/* .no_ooae = */ params.no_ooae,
10271039
/* .mqkv = */ params.mqkv,
1040+
/* .rcache = */ params.rcache,
10281041
/* .buft_overrides=*/ params.buft_overrides.data(),
10291042
};
10301043
instances.push_back(instance);
@@ -1064,6 +1077,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10641077
/* .use_thp = */ params.use_thp,
10651078
/* .no_ooae = */ params.no_ooae,
10661079
/* .mqkv = */ params.mqkv,
1080+
/* .rcache = */ params.rcache,
10671081
/* .buft_overrides=*/ params.buft_overrides.data(),
10681082
};
10691083
instances.push_back(instance);
@@ -1103,6 +1117,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
11031117
/* .use_thp = */ params.use_thp,
11041118
/* .no_ooae = */ params.no_ooae,
11051119
/* .mqkv = */ params.mqkv,
1120+
/* .rcache = */ params.rcache,
11061121
/* .buft_overrides=*/ params.buft_overrides.data(),
11071122
};
11081123
instances.push_back(instance);
@@ -1153,6 +1168,7 @@ struct test {
11531168
bool use_thp = false;
11541169
bool no_ooae = false;
11551170
bool mqkv = false;
1171+
bool rcache = false;
11561172
int n_prompt;
11571173
int n_gen;
11581174
std::string test_time;
@@ -1189,6 +1205,7 @@ struct test {
11891205
mqkv = inst.mqkv;
11901206
fmoe = inst.fmoe;
11911207
ger = inst.ger;
1208+
rcache = inst.rcache;
11921209
no_fug = inst.no_fug;
11931210
use_thp = inst.use_thp;
11941211
no_ooae = inst.no_ooae;
@@ -1282,7 +1299,8 @@ struct test {
12821299
"n_threads", "type_k", "type_v",
12831300
"n_gpu_layers", "split_mode",
12841301
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
1285-
"tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae",
1302+
"tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er",
1303+
"fused_up_gate", "use_thp", "ooae", "rcache",
12861304
"n_prompt", "n_gen", "test_time",
12871305
"avg_ns", "stddev_ns",
12881306
"avg_ts", "stddev_ts", "test",
@@ -1304,7 +1322,8 @@ struct test {
13041322
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
13051323
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
13061324
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
1307-
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv") {
1325+
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv" ||
1326+
field == "rcache") {
13081327
return BOOL;
13091328
}
13101329
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1347,7 +1366,7 @@ struct test {
13471366
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
13481367
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
13491368
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
1350-
std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
1369+
std::to_string(repack), std::to_string(fmoe), std::to_string(ger), std::to_string(rcache),
13511370
std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(mqkv),
13521371
std::to_string(n_prompt), std::to_string(n_gen), test_time,
13531372
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1538,6 +1557,9 @@ struct markdown_printer : public printer {
15381557
if (field == "grouped_er") {
15391558
return 3;
15401559
}
1560+
if (field == "rcache") {
1561+
return 6;
1562+
}
15411563
if (field == "fused_up_gate") {
15421564
return 6;
15431565
}
@@ -1599,8 +1621,8 @@ struct markdown_printer : public printer {
15991621
if (field == "grouped_er") {
16001622
return "ger";
16011623
}
1602-
if (field == "grouped_er") {
1603-
return "ger";
1624+
if (field == "rcache") {
1625+
return "rcache";
16041626
}
16051627
if (field == "fused_up_gate") {
16061628
return "no-fug";
@@ -1687,6 +1709,9 @@ struct markdown_printer : public printer {
16871709
if (params.ger != cmd_params_defaults.ger) {
16881710
fields.emplace_back("grouped_er");
16891711
}
1712+
if (params.rcache != cmd_params_defaults.rcache) {
1713+
fields.emplace_back("rcache");
1714+
}
16901715
if (params.no_fug != cmd_params_defaults.no_fug) {
16911716
fields.emplace_back("fused_up_gate");
16921717
}

0 commit comments

Comments
 (0)