@@ -1405,9 +1405,14 @@ struct server_slot {
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
+        // fprintf(stderr, "need_embd() %d\n", need_embd());
+        // fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
+        // fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+
         return
             !need_embd() ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for the whole batch?
     }
 
     bool can_batch_with(server_slot & other_slot) const {
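
For context on the new LLAMA_POOLING_TYPE_NONE branch: with pooling disabled the context stores one embedding per output token (read back with llama_get_embeddings_ith()) rather than a single pooled vector per sequence, so each ubatch's outputs are complete on their own as long as a memory module supplies the past-token KV. A minimal read-back sketch for the two pooling modes touched here (added for illustration only, not part of this patch):

```cpp
#include "llama.h"

// Sketch (not from this patch): how embedding read-back differs by pooling type.
// With LLAMA_POOLING_TYPE_NONE every output token has its own embedding, so a prompt
// split across ubatches can still be read back one output index at a time.
static const float * read_embd(llama_context * ctx, int32_t out_idx, llama_seq_id seq_id) {
    switch (llama_pooling_type(ctx)) {
        case LLAMA_POOLING_TYPE_NONE:
            // per-token embedding for the out_idx-th output of the last decode
            return llama_get_embeddings_ith(ctx, out_idx);
        case LLAMA_POOLING_TYPE_LAST:
            // one pooled embedding per sequence, taken from its last token
            return llama_get_embeddings_seq(ctx, seq_id);
        default:
            return nullptr;
    }
}
```
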
@@ -3508,6 +3513,13 @@ struct server_context {
                     continue; // continue loop of n_batch
                 }
 
+                for (auto & slot : slots) {
+                    // This should only trigger once, on a non-empty update batch after prompt processing, not during token generation
+                    if (slot.has_mtp) {
+                        mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
+                    }
+                }
+
                 // move the head of the batch forward with the number of tokens we just processed
                 i_next = i + n_tokens;
 
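
The two extra arguments suggest mtp_update_kv_cache() is now meant to touch only the slice of outputs decoded in this pass of the batch loop, i.e. the window [i, i + n_tokens). Both mtp_update_kv_cache() and slot.mtp_kv_update_batch are introduced by this patch and their definitions are not shown here, so the sketch below is only a guess at the intended windowing; the struct and helper names are hypothetical:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical shape of a slot's pending MTP update (assumption, not from the patch).
struct mtp_pending_update {
    std::vector<int32_t> tokens; // tokens whose MTP-head KV entries still need writing
    std::vector<int32_t> pos;    // their positions in the sequence
};

// Sketch of the assumed windowing: only outputs [i, i + n_tokens) from the current
// pass of the batch loop are processed, so the update tracks what was just decoded.
static void mtp_update_kv_cache_sketch(mtp_pending_update & upd, int32_t i, int32_t n_tokens) {
    const int32_t end = std::min<int32_t>(i + n_tokens, (int32_t) upd.tokens.size());
    for (int32_t k = i; k < end; ++k) {
        // ... run the MTP head for upd.tokens[k] at upd.pos[k] and store its KV entries ...
    }
}
```

Whatever the real helper does, consuming (or marking) the processed entries is what would make the in-loop call a no-op during token generation, which is what the comment above expects.
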
@@ -3552,9 +3564,9 @@ struct server_context {
                 common_sampler_accept(slot.smpl, id, true);
 
                 // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-                if (slot.has_mtp) {
-                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
-                }
+                // if (slot.has_mtp) {
+                //     mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+                // }
 
                 slot.n_decoded += 1;
 
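
For reference on why the old call site was commented out: everything after common_sampler_sample()/common_sampler_accept() in this path runs once per generated token, so an MTP KV-cache update placed there would re-execute on every decode step instead of once after prompt processing. A stripped-down sketch of that per-token step (illustrative; common_sampler_* are llama.cpp's existing sampling helpers, the wrapper function is made up here):

```cpp
#include "llama.h"
#include "sampling.h" // common_sampler_* helpers from llama.cpp's common library

// Sketch (simplified, not the actual server code): the per-token sampling step.
// Anything placed after common_sampler_accept() here runs once per generated token,
// which is why the MTP KV-cache update was moved into the per-batch loop instead.
static llama_token sample_one(llama_context * ctx, common_sampler * smpl, int tok_idx) {
    const llama_token id = common_sampler_sample(smpl, ctx, tok_idx);
    common_sampler_accept(smpl, id, /*accept_grammar =*/ true);
    // the commented-out mtp_update_kv_cache() call used to live right here
    return id;
}
```
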