@@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
71
71
llama_batch batch = llama_batch_init (n_kv_max, 0 , 1 );
72
72
73
73
// decode in batches of ctx_params.n_batch tokens
74
- auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
74
+ auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize ) {
75
75
for (int32_t i = 0 ; i < (int32_t ) batch.n_tokens ; i += n_batch) {
76
76
const int32_t n_tokens = std::min (n_batch, (int32_t ) (batch.n_tokens - i));
77
77
@@ -91,7 +91,9 @@ int main(int argc, char ** argv) {
91
91
return false ;
92
92
}
93
93
94
- llama_synchronize (ctx);
94
+ if (synchronize) {
95
+ llama_synchronize (ctx);
96
+ }
95
97
}
96
98
97
99
return true ;
@@ -103,7 +105,7 @@ int main(int argc, char ** argv) {
103
105
common_batch_add (batch, get_token_rand (), i, { 0 }, false );
104
106
}
105
107
106
- if (!decode_helper (ctx, batch, ctx_params.n_batch )) {
108
+ if (!decode_helper (ctx, batch, ctx_params.n_batch , true )) {
107
109
LOG_ERR (" %s: llama_decode() failed\n " , __func__);
108
110
return 1 ;
109
111
}
@@ -138,15 +140,17 @@ int main(int argc, char ** argv) {
138
140
}
139
141
}
140
142
141
- const auto t_pp_start = ggml_time_us ();
142
-
143
143
llama_memory_clear (mem, false );
144
144
145
- if (!decode_helper (ctx, batch, ctx_params.n_batch )) {
145
+ const auto t_pp_start = ggml_time_us ();
146
+
147
+ if (!decode_helper (ctx, batch, ctx_params.n_batch , false )) {
146
148
LOG_ERR (" %s: llama_decode() failed\n " , __func__);
147
149
return 1 ;
148
150
}
149
151
152
+ llama_synchronize (ctx);
153
+
150
154
const auto t_pp_end = ggml_time_us ();
151
155
152
156
if (is_pp_shared) {
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
158
162
// run one dummy token to apply the memory copy
159
163
common_batch_clear (batch);
160
164
common_batch_add (batch, get_token_rand (), pp + 0 , { 0 }, true );
161
- if (!decode_helper (ctx, batch, ctx_params.n_batch )) {
165
+ if (!decode_helper (ctx, batch, ctx_params.n_batch , true )) {
162
166
LOG_ERR (" %s: llama_decode() failed\n " , __func__);
163
167
return 1 ;
164
168
}
@@ -175,7 +179,7 @@ int main(int argc, char ** argv) {
175
179
common_batch_add (batch, get_token_rand (), pp + i, { j }, true );
176
180
}
177
181
178
- if (!decode_helper (ctx, batch, ctx_params.n_batch )) {
182
+ if (!decode_helper (ctx, batch, ctx_params.n_batch , true )) {
179
183
LOG_ERR (" %s: llama_decode() failed\n " , __func__);
180
184
return 1 ;
181
185
}
0 commit comments