Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 87 additions & 27 deletions examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,9 @@ struct test {
int n_prompt;
int n_gen;
std::string test_time;
std::vector<uint64_t> samples_ns;
std::vector<uint64_t> samples_e2e_ns; // e2e latency including prompt processing + token generation
std::vector<uint64_t> samples_prompt_ns; // prompt processing latency
std::vector<uint64_t> samples_gen_ns; // token generation latency

test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) :
cpu_info(get_cpu_info()),
Expand Down Expand Up @@ -939,21 +941,42 @@ struct test {
(void) ctx;
}

uint64_t avg_ns() const { return ::avg(samples_ns); }
// Mean of the recorded end-to-end samples (nanoseconds), computed by ::avg.
uint64_t avg_e2e_ns() const {
    const std::vector<uint64_t> & durations = samples_e2e_ns;
    return ::avg(durations);
}

// Mean of the recorded prompt-processing samples (nanoseconds), computed by ::avg.
uint64_t avg_prompt_ns() const {
    const std::vector<uint64_t> & durations = samples_prompt_ns;
    return ::avg(durations);
}

// Mean of the recorded token-generation samples (nanoseconds), computed by ::avg.
uint64_t avg_gen_ns() const {
    const std::vector<uint64_t> & durations = samples_gen_ns;
    return ::avg(durations);
}

uint64_t stdev_ns() const { return ::stdev(samples_ns); }
// Spread of the recorded end-to-end samples (nanoseconds), as computed by ::stdev.
uint64_t stddev_e2e_ns() const {
    const std::vector<uint64_t> & durations = samples_e2e_ns;
    return ::stdev(durations);
}

// Spread of the recorded prompt-processing samples (nanoseconds), as computed by ::stdev.
uint64_t stddev_prompt_ns() const {
    const std::vector<uint64_t> & durations = samples_prompt_ns;
    return ::stdev(durations);
}

// Spread of the recorded token-generation samples (nanoseconds), as computed by ::stdev.
uint64_t stddev_gen_ns() const {
    const std::vector<uint64_t> & durations = samples_gen_ns;
    return ::stdev(durations);
}

std::vector<double> get_ts() const {
int n_tokens = n_prompt + n_gen;
// Converts per-run durations into per-run throughput.
//
// samples_ns: one duration per run, in nanoseconds.
// n_tokens:   number of tokens attributed to each run.
//
// Returns one tokens-per-second value per sample (1e9 * n_tokens / duration).
// When n_tokens is zero there is nothing to rate, and a single 0 entry is
// returned regardless of how many samples were recorded.
std::vector<double> get_ts(const std::vector<uint64_t> & samples_ns, int n_tokens) const {
    if (n_tokens == 0) {
        return std::vector<double>(1, 0.0);
    }

    std::vector<double> rates;
    rates.reserve(samples_ns.size());
    for (const uint64_t duration_ns : samples_ns) {
        rates.push_back(1e9 * n_tokens / duration_ns);
    }
    return rates;
}

// Per-run end-to-end throughput: tokens per second over the end-to-end samples.
// The token count is n_gen; a prompt-only run (n_gen == 0) is counted as
// producing one token, so the result is never rated against zero tokens.
std::vector<double> get_e2e_ts() const {
    int n_tokens = n_gen;
    if (n_tokens == 0) {
        n_tokens = 1;
    }
    return get_ts(samples_e2e_ns, n_tokens);
}
// Per-run prompt-processing throughput: n_prompt tokens rated against each
// prompt-processing sample via get_ts.
std::vector<double> get_prompt_ts() const {
    std::vector<double> rates = get_ts(samples_prompt_ns, n_prompt);
    return rates;
}
// Per-run token-generation throughput: n_gen tokens rated against each
// token-generation sample via get_ts.
std::vector<double> get_gen_ts() const {
    std::vector<double> rates = get_ts(samples_gen_ns, n_gen);
    return rates;
}

double avg_ts() const { return ::avg(get_ts()); }
// Mean end-to-end throughput (tokens per second) across runs, computed by ::avg.
double avg_e2e_ts() const {
    const std::vector<double> rates = get_e2e_ts();
    return ::avg(rates);
}

// Mean prompt-processing throughput (tokens per second) across runs, computed by ::avg.
double avg_prompt_ts() const {
    const std::vector<double> rates = get_prompt_ts();
    return ::avg(rates);
}

// Mean token-generation throughput (tokens per second) across runs, computed by ::avg.
double avg_gen_ts() const {
    const std::vector<double> rates = get_gen_ts();
    return ::avg(rates);
}

double stdev_ts() const { return ::stdev(get_ts()); }
// Spread of end-to-end throughput (tokens per second) across runs, as computed by ::stdev.
double stdev_e2e_ts() const {
    const std::vector<double> rates = get_e2e_ts();
    return ::stdev(rates);
}

// Spread of prompt-processing throughput (tokens per second) across runs, as computed by ::stdev.
double stdev_prompt_ts() const {
    const std::vector<double> rates = get_prompt_ts();
    return ::stdev(rates);
}

// Spread of token-generation throughput (tokens per second) across runs, as computed by ::stdev.
double stdev_gen_ts() const {
    const std::vector<double> rates = get_gen_ts();
    return ::stdev(rates);
}

static std::string get_backend() {
std::vector<std::string> backends;
Expand All @@ -973,8 +996,10 @@ struct test {
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap",
"embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns",
"avg_ts", "stddev_ts",
"embeddings", "n_prompt", "n_gen", "test_time",
"avg_e2e_ns", "stddev_e2e_ns", "avg_e2e_ts", "stddev_e2e_ts",
"avg_prompt_ns", "stddev_prompt_ns", "avg_prompt_ts", "stddev_prompt_ts",
"avg_gen_ns", "stddev_gen_ns", "avg_gen_ts", "stddev_gen_ts"
};
return fields;
}
Expand All @@ -984,15 +1009,16 @@ struct test {
static field_type get_field_type(const std::string & field) {
if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" ||
field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" ||
field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "avg_ns" ||
field == "stddev_ns") {
field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "avg_e2e_ns" ||
field == "stddev_e2e_ns" || field == "avg_prompt_ns" || field == "stddev_prompt_ns" ||
field == "avg_gen_ns" || field == "stddev_gen_ns") {
return INT;
}
if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
field == "use_mmap" || field == "embeddings") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
if (field == "avg_e2e_ts" || field == "stddev_e2e_ts" || field == "avg_prompt_ts" || field == "stddev_prompt_ts" || field == "avg_gen_ts" || field == "stddev_gen_ts") {
return FLOAT;
}
return STRING;
Expand Down Expand Up @@ -1042,10 +1068,18 @@ struct test {
std::to_string(n_prompt),
std::to_string(n_gen),
test_time,
std::to_string(avg_ns()),
std::to_string(stdev_ns()),
std::to_string(avg_ts()),
std::to_string(stdev_ts()) };
std::to_string(avg_e2e_ns()),
std::to_string(stddev_e2e_ns()),
std::to_string(avg_e2e_ts()),
std::to_string(stdev_e2e_ts()),
std::to_string(avg_prompt_ns()),
std::to_string(stddev_prompt_ns()),
std::to_string(avg_prompt_ts()),
std::to_string(stdev_prompt_ts()),
std::to_string(avg_gen_ns()),
std::to_string(stddev_gen_ns()),
std::to_string(avg_gen_ts()),
std::to_string(stdev_gen_ts()) };
return values;
}

Expand Down Expand Up @@ -1153,8 +1187,12 @@ struct json_printer : public printer {
}
fprintf(fout, " {\n");
print_fields(test::get_fields(), t.get_values());
fprintf(fout, " \"samples_ns\": [ %s ],\n", join(t.samples_ns, ", ").c_str());
fprintf(fout, " \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str());
fprintf(fout, " \"samples_e2e_ns\": [ %s ],\n", join(t.samples_e2e_ns, ", ").c_str());
fprintf(fout, " \"samples_e2e_ts\": [ %s ]\n", join(t.get_e2e_ts(), ", ").c_str());
fprintf(fout, " \"samples_prompt_ns\": [ %s ],\n", join(t.samples_prompt_ns, ", ").c_str());
fprintf(fout, " \"samples_prompt_ts\": [ %s ]\n", join(t.get_prompt_ts(), ", ").c_str());
fprintf(fout, " \"samples_gen_ns\": [ %s ],\n", join(t.samples_gen_ns, ", ").c_str());
fprintf(fout, " \"samples_gen_ts\": [ %s ]\n", join(t.get_gen_ts(), ", ").c_str());
fprintf(fout, " }");
fflush(fout);
}
Expand All @@ -1173,8 +1211,12 @@ struct jsonl_printer : public printer {
// Writes one benchmark result to fout as a single JSON object followed by a
// newline, then flushes the stream.
//
// The object holds the scalar fields emitted by print_fields, followed by six
// raw sample arrays: durations (ns) and throughput (t/s) for the end-to-end,
// prompt-processing and token-generation phases.
//
// Every array except the last is terminated with a comma so the members stay
// properly separated; only "samples_gen_ts" closes without one, directly before
// the closing brace.
void print_test(const test & t) override {
    fprintf(fout, "{");
    print_fields(test::get_fields(), t.get_values());
    fprintf(fout, "\"samples_e2e_ns\": [ %s ],", join(t.samples_e2e_ns, ", ").c_str());
    fprintf(fout, "\"samples_e2e_ts\": [ %s ],", join(t.get_e2e_ts(), ", ").c_str());
    fprintf(fout, "\"samples_prompt_ns\": [ %s ],", join(t.samples_prompt_ns, ", ").c_str());
    fprintf(fout, "\"samples_prompt_ts\": [ %s ],", join(t.get_prompt_ts(), ", ").c_str());
    fprintf(fout, "\"samples_gen_ns\": [ %s ],", join(t.samples_gen_ns, ", ").c_str());
    fprintf(fout, "\"samples_gen_ts\": [ %s ]", join(t.get_gen_ts(), ", ").c_str());
    fprintf(fout, "}\n");
    fflush(fout);
}
Expand All @@ -1187,7 +1229,7 @@ struct markdown_printer : public printer {
if (field == "model") {
return -30;
}
if (field == "t/s") {
if (field == "e2e t/s" || field == "prompt t/s" || field == "gen t/s") {
return 20;
}
if (field == "size" || field == "params") {
Expand Down Expand Up @@ -1314,7 +1356,9 @@ struct markdown_printer : public printer {
fields.emplace_back("embeddings");
}
fields.emplace_back("test");
fields.emplace_back("t/s");
fields.emplace_back("e2e t/s");
fields.emplace_back("prompt t/s");
fields.emplace_back("gen t/s");

fprintf(fout, "|");
for (const auto & field : fields) {
Expand Down Expand Up @@ -1363,8 +1407,14 @@ struct markdown_printer : public printer {
snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
}
value = buf;
} else if (field == "t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
} else if (field == "e2e t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_e2e_ts(), t.stdev_e2e_ts());
value = buf;
} else if (field == "prompt t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_prompt_ts(), t.stdev_prompt_ts());
value = buf;
} else if (field == "gen t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_gen_ts(), t.stdev_gen_ts());
value = buf;
} else if (vmap.find(field) != vmap.end()) {
value = vmap.at(field);
Expand All @@ -1374,7 +1424,7 @@ struct markdown_printer : public printer {
}

int width = get_field_width(field);
if (field == "t/s") {
if (field == "e2e t/s" || field == "prompt t/s" || field == "gen t/s") {
// HACK: the utf-8 character is 2 bytes
width += 1;
}
Expand Down Expand Up @@ -1629,6 +1679,9 @@ int main(int argc, char ** argv) {
}
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}

uint64_t t_gen_start = get_time_ns();

if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
Expand All @@ -1637,8 +1690,15 @@ int main(int argc, char ** argv) {
test_gen(ctx, t.n_gen, t.n_threads);
}

uint64_t t_ns = get_time_ns() - t_start;
t.samples_ns.push_back(t_ns);
uint64_t t_end = get_time_ns();

uint64_t e2e_ns = t_end - t_start;
uint64_t prompt_ns = t_gen_start - t_start;
uint64_t gen_ns = t_end - t_gen_start;

t.samples_e2e_ns.push_back(e2e_ns);
t.samples_prompt_ns.push_back(prompt_ns);
t.samples_gen_ns.push_back(gen_ns);
}

if (p) {
Expand Down