examples/llama-bench/llama-bench.cpp (95 changes: 88 additions & 7 deletions)
@@ -160,6 +160,7 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<std::pair<int, int>> n_pg;
std::vector<std::pair<int, int>> n_gp;
std::vector<int> n_batch;
std::vector<int> n_ubatch;
std::vector<ggml_type> type_k;
@@ -192,6 +193,7 @@ static const cmd_params cmd_params_defaults = {
/* n_prompt */ { 512 },
/* n_gen */ { 128 },
/* n_pg */ {},
/* n_gp */ {},
/* n_batch */ { 2048 },
/* n_ubatch */ { 512 },
/* type_k */ { GGML_TYPE_F16 },
@@ -230,6 +232,8 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -pg <pp,tg> (default: %s)\n",
join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
printf(" -gp <pp,tg> (default: %s)\n",
join(transform_to_str(cmd_params_defaults.n_gp, pair_str), ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n",
join(cmd_params_defaults.n_batch, ",").c_str());
printf(" -ub, --ubatch-size <n> (default: %s)\n",
@@ -366,6 +370,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
} else if (arg == "-gp") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<std::string>(argv[i], ',');
if (p.size() != 2) {
invalid_param = true;
break;
}
params.n_gp.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
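            // usage sketch (hypothetical model path), assuming the parsing
            // above: `llama-bench -m model.gguf -gp 2048,32` benchmarks the
            // generation of 32 tokens following a 2048-token prompt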
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
@@ -615,6 +630,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_pg.empty()) {
params.n_pg = cmd_params_defaults.n_pg;
}
if (params.n_gp.empty()) {
params.n_gp = cmd_params_defaults.n_gp;
}
if (params.n_batch.empty()) {
params.n_batch = cmd_params_defaults.n_batch;
}
@@ -670,7 +688,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
return params;
}

enum test_kind_type {
// measure mean prompt processing rate without token generation
TEST_KIND_PP,
// measure mean token generation rate without prompt processing
TEST_KIND_TG,
// measure mean prompt processing and token generation rate
TEST_KIND_PG,
// measure mean token generation rate after processing a prompt of the given length
TEST_KIND_GP,
};
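// a rough sketch of how the benchmark flags map to test kinds and to the
// labels printed in the "test" column (see markdown_printer below); -p and
// -n are assumed to behave as the existing llama-bench prompt/gen options:
//   -p 512      -> TEST_KIND_PP, "pp512"
//   -n 128      -> TEST_KIND_TG, "tg128"
//   -pg 512,128 -> TEST_KIND_PG, "pp512+tg128"
//   -gp 2048,32 -> TEST_KIND_GP, "tg32@pp2048"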

struct cmd_params_instance {
test_kind_type test_kind;
std::string model;
int n_prompt;
int n_gen;
@@ -757,6 +787,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
@@ -786,6 +817,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_TG,
/* .model = */ m,
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
@@ -815,6 +847,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PG,
/* .model = */ m,
/* .n_prompt = */ n_pg.first,
/* .n_gen = */ n_pg.second,
@@ -838,6 +871,36 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
};
instances.push_back(instance);
}

for (const auto & n_gp : params.n_gp) {
if (n_gp.first == 0 && n_gp.second == 0) {
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_GP,
/* .model = */ m,
/* .n_prompt = */ n_gp.first,
/* .n_gen = */ n_gp.second,
/* .n_batch = */ nb,
/* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .cpu_mask = */ cm,
/* .cpu_strict = */ cs,
/* .poll = */ pl,
/* .n_gpu_layers = */ nl,
/* .rpc_servers = */ rpc,
/* .split_mode = */ sm,
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
};
instances.push_back(instance);
}
}
// clang-format on

@@ -853,6 +916,7 @@ struct test {
std::string model_type;
uint64_t model_size;
uint64_t model_n_params;
test_kind_type test_kind;
int n_batch;
int n_ubatch;
int n_threads;
@@ -881,6 +945,7 @@ struct test {
model_type = buf;
model_size = llama_model_size(lmodel);
model_n_params = llama_model_n_params(lmodel);
test_kind = inst.test_kind;
n_batch = inst.n_batch;
n_ubatch = inst.n_ubatch;
n_threads = inst.n_threads;
@@ -912,7 +977,7 @@ struct test {
uint64_t stdev_ns() const { return ::stdev(samples_ns); }

std::vector<double> get_ts() const {
-        int n_tokens = n_prompt + n_gen;
+        int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
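        // for GP only the generated tokens count toward the reported rate;
        // e.g. a (hypothetical) tg32@pp2048 sample taking 4e9 ns gives
        // n_tokens = 0 + 32 and ts = 1e9 * 32 / 4e9 = 8 t/s, generation only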
std::vector<double> ts;
std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts),
[n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
@@ -1325,12 +1390,22 @@ struct markdown_printer : public printer {
} else if (field == "backend") {
value = test::get_backend();
} else if (field == "test") {
-                if (t.n_prompt > 0 && t.n_gen == 0) {
-                    snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
-                } else if (t.n_gen > 0 && t.n_prompt == 0) {
-                    snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
-                } else {
-                    snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
-                }
+                switch (t.test_kind) {
+                    case TEST_KIND_PP:
+                        snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
+                        break;
+                    case TEST_KIND_TG:
+                        snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
+                        break;
+                    case TEST_KIND_PG:
+                        snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
+                        break;
+                    case TEST_KIND_GP:
+                        snprintf(buf, sizeof(buf), "tg%d@pp%d", t.n_gen, t.n_prompt);
+                        break;
+                    default:
+                        assert(false);
+                        exit(1);
+                }
value = buf;
} else if (field == "t/s") {
@@ -1597,6 +1672,12 @@ int main(int argc, char ** argv) {
}
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}

// we are not interested in the prompt processing time in the gp test
if (t.test_kind == TEST_KIND_GP) {
t_start = get_time_ns();
}
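            // with t_start reset above, the recorded sample should span only
            // the generation phase, so get_ts() ends up dividing n_gen tokens
            // by pure generation time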

if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,