Skip to content

Commit 6b64f74

Browse files
authored
batched-bench : fix unified KV cache handling + pp timing (ggml-org#15562)
* batched-bench : fix unified KV cache handling + pp timing * cont : run dummy token only with split KV cache
1 parent 0d5a470 commit 6b64f74

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tools/batched-bench/batched-bench.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
124124
const int tg = n_tg[i_tg];
125125
const int pl = n_pl[i_pl];
126126

127-
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
127+
const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);
128128

129129
if (n_ctx_req > n_kv_max) {
130130
continue;
@@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
147147
return 1;
148148
}
149149

150+
const auto t_pp_end = ggml_time_us();
151+
150152
if (is_pp_shared) {
151153
for (int32_t i = 1; i < pl; ++i) {
152154
llama_memory_seq_cp(mem, 0, i, -1, -1);
153155
}
154-
}
155156

156-
const auto t_pp_end = ggml_time_us();
157+
if (!params.kv_unified) {
158+
// run one dummy token to apply the memory copy
159+
common_batch_clear(batch);
160+
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
161+
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
162+
LOG_ERR("%s: llama_decode() failed\n", __func__);
163+
return 1;
164+
}
165+
llama_memory_seq_rm(mem, 0, pp, -1);
166+
}
167+
}
157168

158169
const auto t_tg_start = ggml_time_us();
159170

0 commit comments

Comments
 (0)