Skip to content

Commit a885dcf

Browse files
authored
batched-bench : fix llama_synchronize usage during prompt processing (ggml-org#15835)
ggml-ci
1 parent 663027f commit a885dcf

File tree

1 file changed

+12
-8
lines changed

tools/batched-bench/batched-bench.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
7171
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
7272

7373
// decode in batches of ctx_params.n_batch tokens
74-
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
74+
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
7575
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
7676
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
7777

@@ -91,7 +91,9 @@ int main(int argc, char ** argv) {
9191
return false;
9292
}
9393

94-
llama_synchronize(ctx);
94+
if (synchronize) {
95+
llama_synchronize(ctx);
96+
}
9597
}
9698

9799
return true;
@@ -103,7 +105,7 @@ int main(int argc, char ** argv) {
103105
common_batch_add(batch, get_token_rand(), i, { 0 }, false);
104106
}
105107

106-
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
108+
if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
107109
LOG_ERR("%s: llama_decode() failed\n", __func__);
108110
return 1;
109111
}
@@ -138,15 +140,17 @@ int main(int argc, char ** argv) {
138140
}
139141
}
140142

141-
const auto t_pp_start = ggml_time_us();
142-
143143
llama_memory_clear(mem, false);
144144

145-
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
145+
const auto t_pp_start = ggml_time_us();
146+
147+
if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
146148
LOG_ERR("%s: llama_decode() failed\n", __func__);
147149
return 1;
148150
}
149151

152+
llama_synchronize(ctx);
153+
150154
const auto t_pp_end = ggml_time_us();
151155

152156
if (is_pp_shared) {
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
158162
// run one dummy token to apply the memory copy
159163
common_batch_clear(batch);
160164
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
161-
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
165+
if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
162166
LOG_ERR("%s: llama_decode() failed\n", __func__);
163167
return 1;
164168
}
@@ -175,7 +179,7 @@ int main(int argc, char ** argv) {
175179
common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
176180
}
177181

178-
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
182+
if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
179183
LOG_ERR("%s: llama_decode() failed\n", __func__);
180184
return 1;
181185
}

0 commit comments

Comments (0)