@@ -4341,13 +4341,19 @@ size_t gpttype_calc_new_state_kv()
43414341 }
43424342 if (file_format == FileFormat::GGUF_GENERIC)
43434343 {
4344- return llama_state_get_size (llama_ctx_v4);
4344+ size_t s1 = llama_state_get_size (llama_ctx_v4);
4345+ if (draft_ctx)
4346+ {
4347+ size_t s2 = llama_state_get_size (draft_ctx);
4348+ s1 += s2;
4349+ }
4350+ return s1;
43454351 }
43464352 return 0 ;
43474353}
// gpttype_calc_old_state_kv: reports how many bytes the save state stored in
// `slot` currently occupies.
//   slot   - index into `savestates`; no bounds check is performed here, so the
//            caller must pass a valid index.
//   return - after this change, the main-model save-state size plus the
//            draft-model save-state size for that slot (the removed line
//            returned the main-model size only). Both sizes are recorded by
//            gpttype_save_state_kv from llama_state_get_size().
// NOTE(review): the draft size is presumably 0 when no draft context was ever
// saved for this slot - confirm against the savestate struct's initializer.
43484354size_t gpttype_calc_old_state_kv (int slot)
43494355{
4350- return savestates[slot].current_savestate_size ;
4356+ return savestates[slot].current_savestate_size + savestates[slot]. current_draft_savestate_size ;
43514357}
43524358size_t gpttype_calc_old_state_tokencount (int slot)
43534359{
@@ -4365,30 +4371,54 @@ size_t gpttype_save_state_kv(int slot)
43654371 }
43664372 if (file_format == FileFormat::GGUF_GENERIC)
43674373 {
4374+ size_t totalbytes = 0 ;
43684375 if (!savestates[slot].current_savestate_buffer .empty ()) { // JIT free
43694376 savestates[slot].current_savestate_buffer .clear ();
4377+ savestates[slot].current_draft_savestate_buffer .clear ();
43704378 savestates[slot].savestate_context_tokens .clear ();
43714379 savestates[slot].current_savestate_size = 0 ;
4380+ savestates[slot].current_draft_savestate_size = 0 ;
43724381 }
43734382 size_t newsize = llama_state_get_size (llama_ctx_v4);
43744383 try {
43754384 if (savestates[slot].current_savestate_buffer .capacity () < newsize + 512 ) {
4376- savestates[slot].current_savestate_buffer = std::vector<uint8_t >(newsize + 512 );
4385+ savestates[slot].current_savestate_buffer = std::vector<uint8_t >(newsize + 512 ); // add some padding. May throw std::bad_alloc
43774386 } else {
43784387 savestates[slot].current_savestate_buffer .resize (newsize + 512 );
43794388 }
4380- savestates[slot].current_savestate_buffer .resize (newsize + 512 ); // add some padding. May throw std::bad_alloc
43814389 } catch (const std::bad_alloc&) {
43824390 fprintf (stderr, " KV Save State: Failed to allocate %zu bytes.\n " , newsize + 512 );
43834391 return 0 ;
43844392 }
43854393 auto res = llama_state_get_data (llama_ctx_v4, savestates[slot].current_savestate_buffer .data (), newsize);
43864394 if (res > 0 ) {
4395+ totalbytes += res;
43874396 savestates[slot].current_savestate_size = newsize;
43884397 savestates[slot].savestate_context_tokens = current_context_tokens;
43894398 printf (" \n KV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n " ,slot,current_context_tokens.size (),savestates[slot].current_savestate_size /(1024 *1024 ));
43904399 }
4391- return res;
4400+
4401+ if (draft_ctx)
4402+ {
4403+ size_t newsize2 = llama_state_get_size (draft_ctx);
4404+ try {
4405+ if (savestates[slot].current_draft_savestate_buffer .capacity () < newsize2 + 512 ) {
4406+ savestates[slot].current_draft_savestate_buffer = std::vector<uint8_t >(newsize2 + 512 );
4407+ } else {
4408+ savestates[slot].current_draft_savestate_buffer .resize (newsize2 + 512 );
4409+ }
4410+ } catch (const std::bad_alloc&) {
4411+ fprintf (stderr, " KV Save State: Failed to allocate %zu bytes.\n " , newsize2 + 512 );
4412+ return 0 ;
4413+ }
4414+ auto res2 = llama_state_get_data (draft_ctx, savestates[slot].current_draft_savestate_buffer .data (), newsize2);
4415+ if (res2 > 0 ) {
4416+ totalbytes += res2;
4417+ savestates[slot].current_draft_savestate_size = newsize2;
4418+ printf (" \n KV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n " ,slot,current_context_tokens.size (),savestates[slot].current_draft_savestate_size /(1024 *1024 ));
4419+ }
4420+ }
4421+ return totalbytes;
43924422 }
43934423 return 0 ;
43944424}
@@ -4408,6 +4438,12 @@ bool gpttype_load_state_kv(int slot)
44084438 {
44094439 current_context_tokens = savestates[slot].savestate_context_tokens ;
44104440 printf (" \n KV Load SaveState %d: Restored KV with %zu tokens.\n " , slot,current_context_tokens.size ());
4441+ if (draft_ctx && savestates[slot].current_draft_savestate_size >0 )
4442+ {
4443+ llama_memory_clear (llama_get_memory (draft_ctx),true );
4444+ auto res2 = llama_state_set_data (draft_ctx, savestates[slot].current_draft_savestate_buffer .data (), savestates[slot].current_draft_savestate_size );
4445+ printf (" \n KV Load DraftSaveState %d: Restored KV with %zu tokens.\n " , slot,current_context_tokens.size ());
4446+ }
44114447 }
44124448 return (res > 0 );
44134449 }
@@ -4432,6 +4468,15 @@ bool gpttype_clear_state_kv(bool shrink)
44324468 }
44334469 savestates[slot].savestate_context_tokens .clear ();
44344470 savestates[slot].current_savestate_size = 0 ;
4471+ if (draft_ctx && savestates[slot].current_draft_savestate_size >0 )
4472+ {
4473+ savestates[slot].current_draft_savestate_buffer .clear ();
4474+ if (shrink)
4475+ {
4476+ savestates[slot].current_draft_savestate_buffer .shrink_to_fit ();
4477+ }
4478+ savestates[slot].current_draft_savestate_size = 0 ;
4479+ }
44354480 }
44364481 }
44374482 return true ;
0 commit comments