@@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
 
@@ -91,7 +91,9 @@ int main(int argc, char ** argv) {
                 return false;
             }
 
-            llama_synchronize(ctx);
+            if (synchronize) {
+                llama_synchronize(ctx);
+            }
         }
 
         return true;
@@ -103,7 +105,7 @@ int main(int argc, char ** argv) {
             common_batch_add(batch, get_token_rand(), i, { 0 }, false);
         }
 
-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -138,15 +140,17 @@ int main(int argc, char ** argv) {
                     }
                 }
 
-                const auto t_pp_start = ggml_time_us();
-
                 llama_memory_clear(mem, false);
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                const auto t_pp_start = ggml_time_us();
+
+                if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
                     return 1;
                 }
 
+                llama_synchronize(ctx);
+
                 const auto t_pp_end = ggml_time_us();
 
                 if (is_pp_shared) {
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
                         // run one dummy token to apply the memory copy
                         common_batch_clear(batch);
                         common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
-                        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                             LOG_ERR("%s: llama_decode() failed\n", __func__);
                             return 1;
                         }
@@ -175,7 +179,7 @@ int main(int argc, char ** argv) {
                         common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
                     }
 
-                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                         LOG_ERR("%s: llama_decode() failed\n", __func__);
                         return 1;
                     }
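For reference, here is how `decode_helper` reads after this change. The chunk body between the first two hunks is not shown in the diff, so the `batch_view` construction and the `llama_decode` error path below are an assumption based on the usual llama.cpp batching pattern, not a verbatim copy:

```cpp
// Sketch of the post-patch helper (the batch_view/decode body is assumed).
// Splits the batch into n_batch-sized chunks; the new `synchronize` flag lets
// the caller decide whether to block on the backend after every chunk.
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

        // assumed: submit an n_tokens-wide view into the batch at offset i
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,          // no embeddings in this benchmark
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        if (llama_decode(ctx, batch_view) != 0) {
            return false;
        }

        // the per-chunk sync is now opt-in; the timed prompt-processing path
        // passes `false` and synchronizes once after the whole batch instead
        if (synchronize) {
            llama_synchronize(ctx);
        }
    }

    return true;
};
```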
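The other half of the change tightens the prompt-processing measurement: the clock now starts after `llama_memory_clear` rather than before it, the decode runs with `synchronize = false` so the backend can pipeline chunks freely, and a single explicit `llama_synchronize` drains all queued work before the clock stops, so nothing in flight escapes the timed window. A condensed sketch of the resulting pattern (identifiers as in the diff; the throughput lines at the end are a hypothetical illustration, not code from this commit):

```cpp
llama_memory_clear(mem, false);       // excluded from the timed region

const auto t_pp_start = ggml_time_us();

if (!decode_helper(ctx, batch, ctx_params.n_batch, /*synchronize=*/false)) {
    LOG_ERR("%s: llama_decode() failed\n", __func__);
    return 1;
}

llama_synchronize(ctx);               // wait for all queued chunks to finish

const auto t_pp_end = ggml_time_us();

// hypothetical: prompt-processing throughput over the synchronized interval
const double t_pp     = (t_pp_end - t_pp_start) / 1e6;
const double speed_pp = batch.n_tokens / t_pp;
```

The untimed decode calls, including the dummy-token pass that applies the memory copy, keep passing `true`, so those paths still leave the backend idle before the next measurement begins.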