Skip to content

Commit 8cf427f

Browse files
add depth param
1 parent 2016f07 commit 8cf427f

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ struct cmd_params {
160160
std::vector<int> n_prompt;
161161
std::vector<int> n_gen;
162162
std::vector<std::pair<int, int>> n_pg;
163+
std::vector<int> n_depth;
163164
std::vector<int> n_batch;
164165
std::vector<int> n_ubatch;
165166
std::vector<ggml_type> type_k;
@@ -192,6 +193,7 @@ static const cmd_params cmd_params_defaults = {
192193
/* n_prompt */ { 512 },
193194
/* n_gen */ { 128 },
194195
/* n_pg */ {},
196+
/* n_depth */ { 0 },
195197
/* n_batch */ { 2048 },
196198
/* n_ubatch */ { 512 },
197199
/* type_k */ { GGML_TYPE_F16 },
@@ -230,6 +232,8 @@ static void print_usage(int /* argc */, char ** argv) {
230232
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
231233
printf(" -pg <pp,tg> (default: %s)\n",
232234
join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
235+
printf(" -d, --depth <n> (default: %s)\n",
236+
join(cmd_params_defaults.n_depth, ",").c_str());
233237
printf(" -b, --batch-size <n> (default: %s)\n",
234238
join(cmd_params_defaults.n_batch, ",").c_str());
235239
printf(" -ub, --ubatch-size <n> (default: %s)\n",
@@ -366,6 +370,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
366370
break;
367371
}
368372
params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
373+
} else if (arg == "-d" || arg == "--depth") {
374+
if (++i >= argc) {
375+
invalid_param = true;
376+
break;
377+
}
378+
auto p = string_split<int>(argv[i], split_delim);
379+
params.n_depth.insert(params.n_depth.end(), p.begin(), p.end());
369380
} else if (arg == "-b" || arg == "--batch-size") {
370381
if (++i >= argc) {
371382
invalid_param = true;
@@ -615,6 +626,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
615626
if (params.n_pg.empty()) {
616627
params.n_pg = cmd_params_defaults.n_pg;
617628
}
629+
if (params.n_depth.empty()) {
630+
params.n_depth = cmd_params_defaults.n_depth;
631+
}
618632
if (params.n_batch.empty()) {
619633
params.n_batch = cmd_params_defaults.n_batch;
620634
}
@@ -674,6 +688,7 @@ struct cmd_params_instance {
674688
std::string model;
675689
int n_prompt;
676690
int n_gen;
691+
int n_depth;
677692
int n_batch;
678693
int n_ubatch;
679694
ggml_type type_k;
@@ -745,7 +760,7 @@ struct cmd_params_instance {
745760
llama_context_params to_llama_cparams() const {
746761
llama_context_params cparams = llama_context_default_params();
747762

748-
cparams.n_ctx = n_prompt + n_gen;
763+
cparams.n_ctx = n_prompt + n_gen + n_depth;
749764
cparams.n_batch = n_batch;
750765
cparams.n_ubatch = n_ubatch;
751766
cparams.type_k = type_k;
@@ -780,6 +795,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
780795
for (const auto & nt : params.n_threads)
781796
for (const auto & cm : params.cpu_mask)
782797
for (const auto & cs : params.cpu_strict)
798+
for (const auto & nd : params.n_depth)
783799
for (const auto & pl : params.poll) {
784800
for (const auto & n_prompt : params.n_prompt) {
785801
if (n_prompt == 0) {
@@ -789,6 +805,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
789805
/* .model = */ m,
790806
/* .n_prompt = */ n_prompt,
791807
/* .n_gen = */ 0,
808+
/* .n_depth = */ nd,
792809
/* .n_batch = */ nb,
793810
/* .n_ubatch = */ nub,
794811
/* .type_k = */ tk,
@@ -818,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
818835
/* .model = */ m,
819836
/* .n_prompt = */ 0,
820837
/* .n_gen = */ n_gen,
838+
/* .n_depth = */ nd,
821839
/* .n_batch = */ nb,
822840
/* .n_ubatch = */ nub,
823841
/* .type_k = */ tk,
@@ -847,6 +865,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
847865
/* .model = */ m,
848866
/* .n_prompt = */ n_pg.first,
849867
/* .n_gen = */ n_pg.second,
868+
/* .n_depth = */ nd,
850869
/* .n_batch = */ nb,
851870
/* .n_ubatch = */ nub,
852871
/* .type_k = */ tk,
@@ -900,6 +919,7 @@ struct test {
900919
bool embeddings;
901920
int n_prompt;
902921
int n_gen;
922+
int n_depth;
903923
std::string test_time;
904924
std::vector<uint64_t> samples_ns;
905925

@@ -931,6 +951,7 @@ struct test {
931951
embeddings = inst.embeddings;
932952
n_prompt = inst.n_prompt;
933953
n_gen = inst.n_gen;
954+
n_depth = inst.n_depth;
934955
// RFC 3339 date-time format
935956
time_t t = time(NULL);
936957
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
@@ -1362,6 +1383,9 @@ struct markdown_printer : public printer {
13621383
} else {
13631384
snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
13641385
}
1386+
if (t.n_depth > 0) {
1387+
snprintf(buf, sizeof(buf), "%s @ d%d", buf, t.n_depth);
1388+
}
13651389
value = buf;
13661390
} else if (field == "t/s") {
13671391
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
@@ -1603,6 +1627,12 @@ int main(int argc, char ** argv) {
16031627
llama_attach_threadpool(ctx, threadpool, NULL);
16041628

16051629
// warmup run
1630+
// if (t.n_depth > 0) {
1631+
// if (params.progress) {
1632+
// fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup depth run\n", params_idx, params_count);
1633+
// }
1634+
// test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
1635+
// }
16061636
if (t.n_prompt > 0) {
16071637
if (params.progress) {
16081638
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
@@ -1620,6 +1650,14 @@ int main(int argc, char ** argv) {
16201650
for (int i = 0; i < params.reps; i++) {
16211651
llama_kv_self_clear(ctx);
16221652

1653+
if (t.n_depth > 0) {
1654+
if (params.progress) {
1655+
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
1656+
i + 1, params.reps);
1657+
}
1658+
test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
1659+
}
1660+
16231661
uint64_t t_start = get_time_ns();
16241662

16251663
if (t.n_prompt > 0) {

0 commit comments

Comments
 (0)