@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
                 const int tg = n_tg[i_tg];
                 const int pl = n_pl[i_pl];
 
-                const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
+                const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);
 
                 if (n_ctx_req > n_kv_max) {
                     continue;
@@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
                     return 1;
                 }
 
+                const auto t_pp_end = ggml_time_us();
+
                 if (is_pp_shared) {
                     for (int32_t i = 1; i < pl; ++i) {
                         llama_memory_seq_cp(mem, 0, i, -1, -1);
                     }
-                }
 
-                const auto t_pp_end = ggml_time_us();
+                    if (!params.kv_unified) {
+                        // run one dummy token to apply the memory copy
+                        common_batch_clear(batch);
+                        common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                            LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            return 1;
+                        }
+                        llama_memory_seq_rm(mem, 0, pp, -1);
+                    }
+                }
 
                 const auto t_tg_start = ggml_time_us();
 
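For reference, a minimal standalone sketch of the context-size requirement computed in the first hunk; the helper name `n_ctx_required` is hypothetical and not part of the tool. With a unified KV cache the shared prompt is stored once, while a per-sequence (split) cache keeps its own copy of the `pp` prompt tokens for each of the `pl` parallel sequences:

```cpp
// Hypothetical standalone illustration of the n_ctx_req computation above.
#include <cstdio>

static int n_ctx_required(bool is_pp_shared, bool kv_unified, int pp, int tg, int pl) {
    if (!is_pp_shared) {
        // each sequence has its own prompt and its own generated tokens
        return pl*(pp + tg);
    }
    // shared prompt: stored once in a unified cache, once per sequence otherwise
    return (kv_unified ? pp : pl*pp) + pl*tg;
}

int main() {
    // example: 512-token shared prompt, 128 generated tokens, 4 parallel sequences
    printf("unified KV cache: %d\n", n_ctx_required(true, true,  512, 128, 4)); // 512 + 4*128   = 1024
    printf("split KV cache:   %d\n", n_ctx_required(true, false, 512, 128, 4)); // 4*512 + 4*128 = 2560
    return 0;
}
```

As the in-diff comment states, the dummy token decoded in the second hunk appears intended to apply the `llama_memory_seq_cp` copies in the non-unified case before token generation is timed; the dummy token is then removed again with `llama_memory_seq_rm` so it does not affect the measurement.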