@@ -160,6 +160,7 @@ struct cmd_params {
160160 std::vector<int > n_prompt;
161161 std::vector<int > n_gen;
162162 std::vector<std::pair<int , int >> n_pg;
163+ std::vector<int > n_depth;
163164 std::vector<int > n_batch;
164165 std::vector<int > n_ubatch;
165166 std::vector<ggml_type> type_k;
@@ -192,6 +193,7 @@ static const cmd_params cmd_params_defaults = {
192193 /* n_prompt */ { 512 },
193194 /* n_gen */ { 128 },
194195 /* n_pg */ {},
196+ /* n_depth */ { 0 },
195197 /* n_batch */ { 2048 },
196198 /* n_ubatch */ { 512 },
197199 /* type_k */ { GGML_TYPE_F16 },
@@ -230,6 +232,8 @@ static void print_usage(int /* argc */, char ** argv) {
230232 printf (" -n, --n-gen <n> (default: %s)\n " , join (cmd_params_defaults.n_gen , " ," ).c_str ());
231233 printf (" -pg <pp,tg> (default: %s)\n " ,
232234 join (transform_to_str (cmd_params_defaults.n_pg , pair_str), " ," ).c_str ());
235+ printf (" -d, --depth <n> (default: %s)\n " ,
236+ join (cmd_params_defaults.n_depth , " ," ).c_str ());
233237 printf (" -b, --batch-size <n> (default: %s)\n " ,
234238 join (cmd_params_defaults.n_batch , " ," ).c_str ());
235239 printf (" -ub, --ubatch-size <n> (default: %s)\n " ,
@@ -366,6 +370,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
366370 break ;
367371 }
368372 params.n_pg .push_back ({ std::stoi (p[0 ]), std::stoi (p[1 ]) });
373+ } else if (arg == " -d" || arg == " --depth" ) {
374+ if (++i >= argc) {
375+ invalid_param = true ;
376+ break ;
377+ }
378+ auto p = string_split<int >(argv[i], split_delim);
379+ params.n_depth .insert (params.n_depth .end (), p.begin (), p.end ());
369380 } else if (arg == " -b" || arg == " --batch-size" ) {
370381 if (++i >= argc) {
371382 invalid_param = true ;
@@ -615,6 +626,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
615626 if (params.n_pg .empty ()) {
616627 params.n_pg = cmd_params_defaults.n_pg ;
617628 }
629+ if (params.n_depth .empty ()) {
630+ params.n_depth = cmd_params_defaults.n_depth ;
631+ }
618632 if (params.n_batch .empty ()) {
619633 params.n_batch = cmd_params_defaults.n_batch ;
620634 }
@@ -674,6 +688,7 @@ struct cmd_params_instance {
674688 std::string model;
675689 int n_prompt;
676690 int n_gen;
691+ int n_depth;
677692 int n_batch;
678693 int n_ubatch;
679694 ggml_type type_k;
@@ -745,7 +760,7 @@ struct cmd_params_instance {
745760 llama_context_params to_llama_cparams () const {
746761 llama_context_params cparams = llama_context_default_params ();
747762
748- cparams.n_ctx = n_prompt + n_gen;
763+ cparams.n_ctx = n_prompt + n_gen + n_depth ;
749764 cparams.n_batch = n_batch;
750765 cparams.n_ubatch = n_ubatch;
751766 cparams.type_k = type_k;
@@ -780,6 +795,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
780795 for (const auto & nt : params.n_threads )
781796 for (const auto & cm : params.cpu_mask )
782797 for (const auto & cs : params.cpu_strict )
798+ for (const auto & nd : params.n_depth )
783799 for (const auto & pl : params.poll ) {
784800 for (const auto & n_prompt : params.n_prompt ) {
785801 if (n_prompt == 0 ) {
@@ -789,6 +805,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
789805 /* .model = */ m,
790806 /* .n_prompt = */ n_prompt,
791807 /* .n_gen = */ 0 ,
808+ /* .n_depth = */ nd,
792809 /* .n_batch = */ nb,
793810 /* .n_ubatch = */ nub,
794811 /* .type_k = */ tk,
@@ -818,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
818835 /* .model = */ m,
819836 /* .n_prompt = */ 0 ,
820837 /* .n_gen = */ n_gen,
838+ /* .n_depth = */ nd,
821839 /* .n_batch = */ nb,
822840 /* .n_ubatch = */ nub,
823841 /* .type_k = */ tk,
@@ -847,6 +865,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
847865 /* .model = */ m,
848866 /* .n_prompt = */ n_pg.first ,
849867 /* .n_gen = */ n_pg.second ,
868+ /* .n_depth = */ nd,
850869 /* .n_batch = */ nb,
851870 /* .n_ubatch = */ nub,
852871 /* .type_k = */ tk,
@@ -900,6 +919,7 @@ struct test {
900919 bool embeddings;
901920 int n_prompt;
902921 int n_gen;
922+ int n_depth;
903923 std::string test_time;
904924 std::vector<uint64_t > samples_ns;
905925
@@ -931,6 +951,7 @@ struct test {
931951 embeddings = inst.embeddings ;
932952 n_prompt = inst.n_prompt ;
933953 n_gen = inst.n_gen ;
954+ n_depth = inst.n_depth ;
934955 // RFC 3339 date-time format
935956 time_t t = time (NULL );
936957 std::strftime (buf, sizeof (buf), " %FT%TZ" , gmtime (&t));
@@ -1362,6 +1383,9 @@ struct markdown_printer : public printer {
13621383 } else {
13631384 snprintf (buf, sizeof (buf), " pp%d+tg%d" , t.n_prompt , t.n_gen );
13641385 }
1386+ if (t.n_depth > 0 ) {
1387+ snprintf (buf, sizeof (buf), " %s @ d%d" , buf, t.n_depth );
1388+ }
13651389 value = buf;
13661390 } else if (field == " t/s" ) {
13671391 snprintf (buf, sizeof (buf), " %.2f ± %.2f" , t.avg_ts (), t.stdev_ts ());
@@ -1603,6 +1627,12 @@ int main(int argc, char ** argv) {
16031627 llama_attach_threadpool (ctx, threadpool, NULL );
16041628
16051629 // warmup run
1630+ // if (t.n_depth > 0) {
1631+ // if (params.progress) {
1632+ // fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup depth run\n", params_idx, params_count);
1633+ // }
1634+ // test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
1635+ // }
16061636 if (t.n_prompt > 0 ) {
16071637 if (params.progress ) {
16081638 fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup prompt run\n " , params_idx, params_count);
@@ -1620,6 +1650,14 @@ int main(int argc, char ** argv) {
16201650 for (int i = 0 ; i < params.reps ; i++) {
16211651 llama_kv_self_clear (ctx);
16221652
1653+ if (t.n_depth > 0 ) {
1654+ if (params.progress ) {
1655+ fprintf (stderr, " llama-bench: benchmark %d/%zu: depth run %d/%d\n " , params_idx, params_count,
1656+ i + 1 , params.reps );
1657+ }
1658+ test_prompt (ctx, t.n_depth , t.n_batch , t.n_threads );
1659+ }
1660+
16231661 uint64_t t_start = get_time_ns ();
16241662
16251663 if (t.n_prompt > 0 ) {
0 commit comments