Skip to content

Commit bb6569e

Browse files
committed
llama-bench : add -gp <pp,tg> test measuring token generation rate at given prompt length
1 parent 017cc5f commit bb6569e

File tree

1 file changed

+88
-7
lines changed

1 file changed

+88
-7
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ struct cmd_params {
160160
std::vector<int> n_prompt;
161161
std::vector<int> n_gen;
162162
std::vector<std::pair<int, int>> n_pg;
163+
std::vector<std::pair<int, int>> n_gp;
163164
std::vector<int> n_batch;
164165
std::vector<int> n_ubatch;
165166
std::vector<ggml_type> type_k;
@@ -192,6 +193,7 @@ static const cmd_params cmd_params_defaults = {
192193
/* n_prompt */ { 512 },
193194
/* n_gen */ { 128 },
194195
/* n_pg */ {},
196+
/* n_gp */ {},
195197
/* n_batch */ { 2048 },
196198
/* n_ubatch */ { 512 },
197199
/* type_k */ { GGML_TYPE_F16 },
@@ -230,6 +232,8 @@ static void print_usage(int /* argc */, char ** argv) {
230232
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
231233
printf(" -pg <pp,tg> (default: %s)\n",
232234
join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
235+
printf(" -gp <pp,tg> (default: %s)\n",
236+
join(transform_to_str(cmd_params_defaults.n_gp, pair_str), ",").c_str());
233237
printf(" -b, --batch-size <n> (default: %s)\n",
234238
join(cmd_params_defaults.n_batch, ",").c_str());
235239
printf(" -ub, --ubatch-size <n> (default: %s)\n",
@@ -366,6 +370,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
366370
break;
367371
}
368372
params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
373+
} else if (arg == "-gp") {
374+
if (++i >= argc) {
375+
invalid_param = true;
376+
break;
377+
}
378+
auto p = string_split<std::string>(argv[i], ',');
379+
if (p.size() != 2) {
380+
invalid_param = true;
381+
break;
382+
}
383+
params.n_gp.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
369384
} else if (arg == "-b" || arg == "--batch-size") {
370385
if (++i >= argc) {
371386
invalid_param = true;
@@ -615,6 +630,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
615630
if (params.n_pg.empty()) {
616631
params.n_pg = cmd_params_defaults.n_pg;
617632
}
633+
if (params.n_gp.empty()) {
634+
params.n_gp = cmd_params_defaults.n_gp;
635+
}
618636
if (params.n_batch.empty()) {
619637
params.n_batch = cmd_params_defaults.n_batch;
620638
}
@@ -670,7 +688,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
670688
return params;
671689
}
672690

691+
enum test_kind_type {
692+
// measure mean prompt processing rate without token generation
693+
TEST_KIND_PP,
694+
// measure mean token generation rate without prompt processing
695+
TEST_KIND_TG,
696+
// measure mean prompt processing and token generation rate
697+
TEST_KIND_PG,
698+
// measure mean token generation rate after processing prompt of given length
699+
TEST_KIND_GP,
700+
};
701+
673702
struct cmd_params_instance {
703+
test_kind_type test_kind;
674704
std::string model;
675705
int n_prompt;
676706
int n_gen;
@@ -757,6 +787,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
757787
continue;
758788
}
759789
cmd_params_instance instance = {
790+
/* .test_kind = */ TEST_KIND_PP,
760791
/* .model = */ m,
761792
/* .n_prompt = */ n_prompt,
762793
/* .n_gen = */ 0,
@@ -786,6 +817,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
786817
continue;
787818
}
788819
cmd_params_instance instance = {
820+
/* .test_kind = */ TEST_KIND_TG,
789821
/* .model = */ m,
790822
/* .n_prompt = */ 0,
791823
/* .n_gen = */ n_gen,
@@ -815,6 +847,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
815847
continue;
816848
}
817849
cmd_params_instance instance = {
850+
/* .test_kind = */ TEST_KIND_PG,
818851
/* .model = */ m,
819852
/* .n_prompt = */ n_pg.first,
820853
/* .n_gen = */ n_pg.second,
@@ -838,6 +871,36 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
838871
};
839872
instances.push_back(instance);
840873
}
874+
875+
for (const auto & n_gp : params.n_gp) {
876+
if (n_gp.first == 0 && n_gp.second == 0) {
877+
continue;
878+
}
879+
cmd_params_instance instance = {
880+
/* .test_kind = */ TEST_KIND_GP,
881+
/* .model = */ m,
882+
/* .n_prompt = */ n_gp.first,
883+
/* .n_gen = */ n_gp.second,
884+
/* .n_batch = */ nb,
885+
/* .n_ubatch = */ nub,
886+
/* .type_k = */ tk,
887+
/* .type_v = */ tv,
888+
/* .n_threads = */ nt,
889+
/* .cpu_mask = */ cm,
890+
/* .cpu_strict = */ cs,
891+
/* .poll = */ pl,
892+
/* .n_gpu_layers = */ nl,
893+
/* .rpc_servers = */ rpc,
894+
/* .split_mode = */ sm,
895+
/* .main_gpu = */ mg,
896+
/* .no_kv_offload= */ nkvo,
897+
/* .flash_attn = */ fa,
898+
/* .tensor_split = */ ts,
899+
/* .use_mmap = */ mmp,
900+
/* .embeddings = */ embd,
901+
};
902+
instances.push_back(instance);
903+
}
841904
}
842905
// clang-format on
843906

@@ -853,6 +916,7 @@ struct test {
853916
std::string model_type;
854917
uint64_t model_size;
855918
uint64_t model_n_params;
919+
test_kind_type test_kind;
856920
int n_batch;
857921
int n_ubatch;
858922
int n_threads;
@@ -881,6 +945,7 @@ struct test {
881945
model_type = buf;
882946
model_size = llama_model_size(lmodel);
883947
model_n_params = llama_model_n_params(lmodel);
948+
test_kind = inst.test_kind;
884949
n_batch = inst.n_batch;
885950
n_ubatch = inst.n_ubatch;
886951
n_threads = inst.n_threads;
@@ -912,7 +977,7 @@ struct test {
912977
uint64_t stdev_ns() const { return ::stdev(samples_ns); }
913978

914979
std::vector<double> get_ts() const {
915-
int n_tokens = n_prompt + n_gen;
980+
int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
916981
std::vector<double> ts;
917982
std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts),
918983
[n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
@@ -1325,12 +1390,22 @@ struct markdown_printer : public printer {
13251390
} else if (field == "backend") {
13261391
value = test::get_backend();
13271392
} else if (field == "test") {
1328-
if (t.n_prompt > 0 && t.n_gen == 0) {
1329-
snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
1330-
} else if (t.n_gen > 0 && t.n_prompt == 0) {
1331-
snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
1332-
} else {
1333-
snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
1393+
switch (t.test_kind) {
1394+
case TEST_KIND_PP:
1395+
snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
1396+
break;
1397+
case TEST_KIND_TG:
1398+
snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
1399+
break;
1400+
case TEST_KIND_PG:
1401+
snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
1402+
break;
1403+
case TEST_KIND_GP:
1404+
snprintf(buf, sizeof(buf), "tg%d@pp%d", t.n_gen, t.n_prompt);
1405+
break;
1406+
default:
1407+
assert(false);
1408+
exit(1);
13341409
}
13351410
value = buf;
13361411
} else if (field == "t/s") {
@@ -1597,6 +1672,12 @@ int main(int argc, char ** argv) {
15971672
}
15981673
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
15991674
}
1675+
1676+
// we are not interested in prompt processing time in g@p test
1677+
if (t.test_kind == TEST_KIND_GP) {
1678+
t_start = get_time_ns();
1679+
}
1680+
16001681
if (t.n_gen > 0) {
16011682
if (params.progress) {
16021683
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,

0 commit comments

Comments
 (0)