@@ -160,6 +160,7 @@ struct cmd_params {
160160 std::vector<int > n_prompt;
161161 std::vector<int > n_gen;
162162 std::vector<std::pair<int , int >> n_pg;
163+ std::vector<std::pair<int , int >> n_gp;
163164 std::vector<int > n_batch;
164165 std::vector<int > n_ubatch;
165166 std::vector<ggml_type> type_k;
@@ -192,6 +193,7 @@ static const cmd_params cmd_params_defaults = {
192193 /* n_prompt */ { 512 },
193194 /* n_gen */ { 128 },
194195 /* n_pg */ {},
196+ /* n_gp */ {},
195197 /* n_batch */ { 2048 },
196198 /* n_ubatch */ { 512 },
197199 /* type_k */ { GGML_TYPE_F16 },
@@ -230,6 +232,8 @@ static void print_usage(int /* argc */, char ** argv) {
230232 printf (" -n, --n-gen <n> (default: %s)\n " , join (cmd_params_defaults.n_gen , " ," ).c_str ());
231233 printf (" -pg <pp,tg> (default: %s)\n " ,
232234 join (transform_to_str (cmd_params_defaults.n_pg , pair_str), " ," ).c_str ());
235+ printf (" -gp <pp,tg> (default: %s)\n " ,
236+ join (transform_to_str (cmd_params_defaults.n_gp , pair_str), " ," ).c_str ());
233237 printf (" -b, --batch-size <n> (default: %s)\n " ,
234238 join (cmd_params_defaults.n_batch , " ," ).c_str ());
235239 printf (" -ub, --ubatch-size <n> (default: %s)\n " ,
@@ -366,6 +370,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
366370 break ;
367371 }
368372 params.n_pg .push_back ({ std::stoi (p[0 ]), std::stoi (p[1 ]) });
373+ } else if (arg == " -gp" ) {
374+ if (++i >= argc) {
375+ invalid_param = true ;
376+ break ;
377+ }
378+ auto p = string_split<std::string>(argv[i], ' ,' );
379+ if (p.size () != 2 ) {
380+ invalid_param = true ;
381+ break ;
382+ }
383+ params.n_gp .push_back ({ std::stoi (p[0 ]), std::stoi (p[1 ]) });
369384 } else if (arg == " -b" || arg == " --batch-size" ) {
370385 if (++i >= argc) {
371386 invalid_param = true ;
@@ -615,6 +630,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
615630 if (params.n_pg .empty ()) {
616631 params.n_pg = cmd_params_defaults.n_pg ;
617632 }
633+ if (params.n_gp .empty ()) {
634+ params.n_gp = cmd_params_defaults.n_gp ;
635+ }
618636 if (params.n_batch .empty ()) {
619637 params.n_batch = cmd_params_defaults.n_batch ;
620638 }
@@ -670,7 +688,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
670688 return params;
671689}
672690
// Identifies which benchmark a cmd_params_instance / test describes. The kind decides
// which tokens count toward the reported rate and how the test is labelled when printed.
691+ enum test_kind_type {
692+ // prompt processing only: instances are built with n_gen = 0, so the rate covers n_prompt tokens
693+ TEST_KIND_PP,
694+ // token generation only: instances are built with n_prompt = 0, so the rate covers n_gen tokens
695+ TEST_KIND_TG,
696+ // prompt processing followed by token generation: one mean rate over n_prompt + n_gen tokens
697+ TEST_KIND_PG,
698+ // token generation after a prompt of the given length was processed: only n_gen tokens are counted
699+ TEST_KIND_GP,
700+ };
701+
673702struct cmd_params_instance {
703+ test_kind_type test_kind;
674704 std::string model;
675705 int n_prompt;
676706 int n_gen;
@@ -757,6 +787,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
757787 continue ;
758788 }
759789 cmd_params_instance instance = {
790+ /* .test_kind = */ TEST_KIND_PP,
760791 /* .model = */ m,
761792 /* .n_prompt = */ n_prompt,
762793 /* .n_gen = */ 0 ,
@@ -786,6 +817,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
786817 continue ;
787818 }
788819 cmd_params_instance instance = {
820+ /* .test_kind = */ TEST_KIND_TG,
789821 /* .model = */ m,
790822 /* .n_prompt = */ 0 ,
791823 /* .n_gen = */ n_gen,
@@ -815,6 +847,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
815847 continue ;
816848 }
817849 cmd_params_instance instance = {
850+ /* .test_kind = */ TEST_KIND_PG,
818851 /* .model = */ m,
819852 /* .n_prompt = */ n_pg.first ,
820853 /* .n_gen = */ n_pg.second ,
@@ -838,6 +871,36 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
838871 };
839872 instances.push_back (instance);
840873 }
874+
// Emit one TEST_KIND_GP instance per -gp pair: pair.first becomes n_prompt (the prompt
// length processed before generation) and pair.second becomes n_gen. The remaining fields
// (m, nb, nub, tk, ...) are variables defined outside this excerpt.
875+ for (const auto & n_gp : params.n_gp ) {
// Only the all-zero pair is skipped.
// NOTE(review): a pair with n_gen == 0 still produces an instance, but get_ts() counts
// only n_gen tokens for TEST_KIND_GP, so its rate would be computed from 0 tokens --
// confirm whether such pairs should be skipped as well.
876+ if (n_gp.first == 0 && n_gp.second == 0 ) {
877+ continue ;
878+ }
879+ cmd_params_instance instance = {
880+ /* .test_kind = */ TEST_KIND_GP,
881+ /* .model = */ m,
882+ /* .n_prompt = */ n_gp.first ,
883+ /* .n_gen = */ n_gp.second ,
884+ /* .n_batch = */ nb,
885+ /* .n_ubatch = */ nub,
886+ /* .type_k = */ tk,
887+ /* .type_v = */ tv,
888+ /* .n_threads = */ nt,
889+ /* .cpu_mask = */ cm,
890+ /* .cpu_strict = */ cs,
891+ /* .poll = */ pl,
892+ /* .n_gpu_layers = */ nl,
893+ /* .rpc_servers = */ rpc,
894+ /* .split_mode = */ sm,
895+ /* .main_gpu = */ mg,
896+ /* .no_kv_offload= */ nkvo,
897+ /* .flash_attn = */ fa,
898+ /* .tensor_split = */ ts,
899+ /* .use_mmap = */ mmp,
900+ /* .embeddings = */ embd,
901+ };
902+ instances.push_back (instance);
903+ }
841904 }
842905 // clang-format on
843906
@@ -853,6 +916,7 @@ struct test {
853916 std::string model_type;
854917 uint64_t model_size;
855918 uint64_t model_n_params;
919+ test_kind_type test_kind;
856920 int n_batch;
857921 int n_ubatch;
858922 int n_threads;
@@ -881,6 +945,7 @@ struct test {
881945 model_type = buf;
882946 model_size = llama_model_size (lmodel);
883947 model_n_params = llama_model_n_params (lmodel);
948+ test_kind = inst.test_kind ;
884949 n_batch = inst.n_batch ;
885950 n_ubatch = inst.n_ubatch ;
886951 n_threads = inst.n_threads ;
@@ -912,7 +977,7 @@ struct test {
912977 uint64_t stdev_ns () const { return ::stdev (samples_ns); }
913978
914979 std::vector<double > get_ts () const {
915- int n_tokens = n_prompt + n_gen;
980+ int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
916981 std::vector<double > ts;
917982 std::transform (samples_ns.begin (), samples_ns.end (), std::back_inserter (ts),
918983 [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
@@ -1325,12 +1390,22 @@ struct markdown_printer : public printer {
13251390 } else if (field == " backend" ) {
13261391 value = test::get_backend ();
13271392 } else if (field == " test" ) {
1328- if (t.n_prompt > 0 && t.n_gen == 0 ) {
1329- snprintf (buf, sizeof (buf), " pp%d" , t.n_prompt );
1330- } else if (t.n_gen > 0 && t.n_prompt == 0 ) {
1331- snprintf (buf, sizeof (buf), " tg%d" , t.n_gen );
1332- } else {
1333- snprintf (buf, sizeof (buf), " pp%d+tg%d" , t.n_prompt , t.n_gen );
// Build the "test" column label from the explicit test kind instead of inferring it
// from which of n_prompt / n_gen is zero.
1393+ switch (t.test_kind ) {
// prompt processing only: label carries the prompt length
1394+ case TEST_KIND_PP:
1395+ snprintf (buf, sizeof (buf), " pp%d" , t.n_prompt );
1396+ break ;
// token generation only: label carries the number of generated tokens
1397+ case TEST_KIND_TG:
1398+ snprintf (buf, sizeof (buf), " tg%d" , t.n_gen );
1399+ break ;
// combined run: prompt length first, then generated tokens
1400+ case TEST_KIND_PG:
1401+ snprintf (buf, sizeof (buf), " pp%d+tg%d" , t.n_prompt , t.n_gen );
1402+ break ;
// generation measured after a prompt: generated tokens first, then the prompt length
1403+ case TEST_KIND_GP:
1404+ snprintf (buf, sizeof (buf), " tg%d@pp%d" , t.n_gen , t.n_prompt );
1405+ break ;
// Not reached for the four enumerators handled above. exit(1) keeps the failure
// fatal in builds where assert is compiled out (NDEBUG).
1406+ default :
1407+ assert (false );
1408+ exit (1 );
13341409 }
13351410 value = buf;
13361411 } else if (field == " t/s" ) {
@@ -1597,6 +1672,12 @@ int main(int argc, char ** argv) {
15971672 }
15981673 test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
15991674 }
1675+
1676+ // in the tg@pp test (-gp) only token generation is timed: reset t_start so that the prompt processing above is not counted
1677+ if (t.test_kind == TEST_KIND_GP) {
1678+ t_start = get_time_ns ();
1679+ }
1680+
16001681 if (t.n_gen > 0 ) {
16011682 if (params.progress ) {
16021683 fprintf (stderr, " llama-bench: benchmark %d/%zu: generation run %d/%d\n " , params_idx, params_count,
0 commit comments