@@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
 
@@ -91,7 +91,9 @@ int main(int argc, char ** argv) {
                 return false;
             }
 
-            llama_synchronize(ctx);
+            if (synchronize) {
+                llama_synchronize(ctx);
+            }
         }
 
         return true;
@@ -103,7 +105,7 @@ int main(int argc, char ** argv) {
             common_batch_add(batch, get_token_rand(), i, { 0 }, false);
         }
 
-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -138,15 +140,17 @@ int main(int argc, char ** argv) {
                     }
                 }
 
-                const auto t_pp_start = ggml_time_us();
-
                 llama_memory_clear(mem, false);
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                const auto t_pp_start = ggml_time_us();
+
+                if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
                     return 1;
                 }
 
+                llama_synchronize(ctx);
+
                 const auto t_pp_end = ggml_time_us();
 
                 if (is_pp_shared) {
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
                     // run one dummy token to apply the memory copy
                     common_batch_clear(batch);
                     common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
-                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                         LOG_ERR("%s: llama_decode() failed\n", __func__);
                         return 1;
                     }
@@ -175,7 +179,7 @@ int main(int argc, char ** argv) {
                         common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
                     }
 
-                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                         LOG_ERR("%s: llama_decode() failed\n", __func__);
                         return 1;
                     }
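
The point of the new `synchronize` flag is to keep the timed prompt-processing run from waiting on every `n_batch`-sized chunk. For reference, here is a sketch of the full `decode_helper` with the hunks applied; the `llama_batch` view construction between the first two hunks is elided in the diff and reconstructed here from the surrounding file, so it may differ in detail:

```cpp
// decode in batches of n_batch tokens; when synchronize == false the chunks
// are only submitted, and the caller waits with a single llama_synchronize()
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

        // view into the persistent batch, covering tokens [i, i + n_tokens)
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr, // token ids are used, not embeddings
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        if (llama_decode(ctx, batch_view) != 0) {
            return false;
        }

        if (synchronize) {
            llama_synchronize(ctx);
        }
    }

    return true;
};
```

With `synchronize = false`, the prompt-processing benchmark submits all chunks back to back, and the explicit `llama_synchronize(ctx)` added before `t_pp_end` waits exactly once, so `t_pp` measures the pipelined decode rather than including a host/device round trip per chunk. The warmup, the dummy-token decode after the sequence copy, and the text-generation loop keep per-chunk synchronization by passing `true`.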