
Commit 39b0699

committed
fixed savestates with drafting
1 parent df47b51 commit 39b0699

2 files changed: 52 additions & 5 deletions


gpttype_adapter.cpp

Lines changed: 50 additions & 5 deletions
@@ -4341,13 +4341,19 @@ size_t gpttype_calc_new_state_kv()
     }
     if(file_format == FileFormat::GGUF_GENERIC)
     {
-        return llama_state_get_size(llama_ctx_v4);
+        size_t s1 = llama_state_get_size(llama_ctx_v4);
+        if(draft_ctx)
+        {
+            size_t s2 = llama_state_get_size(draft_ctx);
+            s1 += s2;
+        }
+        return s1;
     }
     return 0;
 }
 size_t gpttype_calc_old_state_kv(int slot)
 {
-    return savestates[slot].current_savestate_size;
+    return savestates[slot].current_savestate_size + savestates[slot].current_draft_savestate_size;
 }
 size_t gpttype_calc_old_state_tokencount(int slot)
 {
@@ -4365,30 +4371,54 @@ size_t gpttype_save_state_kv(int slot)
     }
     if(file_format == FileFormat::GGUF_GENERIC)
     {
+        size_t totalbytes = 0;
         if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free
             savestates[slot].current_savestate_buffer.clear();
+            savestates[slot].current_draft_savestate_buffer.clear();
             savestates[slot].savestate_context_tokens.clear();
             savestates[slot].current_savestate_size = 0;
+            savestates[slot].current_draft_savestate_size = 0;
         }
         size_t newsize = llama_state_get_size(llama_ctx_v4);
         try {
             if (savestates[slot].current_savestate_buffer.capacity() < newsize + 512) {
-                savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512);
+                savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512); // add some padding. May throw std::bad_alloc
             } else {
                 savestates[slot].current_savestate_buffer.resize(newsize + 512);
             }
-            savestates[slot].current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc
         } catch (const std::bad_alloc&) {
             fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512);
             return 0;
         }
         auto res = llama_state_get_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), newsize);
         if (res > 0) {
+            totalbytes += res;
             savestates[slot].current_savestate_size = newsize;
             savestates[slot].savestate_context_tokens = current_context_tokens;
             printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
         }
-        return res;
+
+        if(draft_ctx)
+        {
+            size_t newsize2 = llama_state_get_size(draft_ctx);
+            try {
+                if (savestates[slot].current_draft_savestate_buffer.capacity() < newsize2 + 512) {
+                    savestates[slot].current_draft_savestate_buffer = std::vector<uint8_t>(newsize2 + 512);
+                } else {
+                    savestates[slot].current_draft_savestate_buffer.resize(newsize2 + 512);
+                }
+            } catch (const std::bad_alloc&) {
+                fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize2 + 512);
+                return 0;
+            }
+            auto res2 = llama_state_get_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), newsize2);
+            if (res2 > 0) {
+                totalbytes += res2;
+                savestates[slot].current_draft_savestate_size = newsize2;
+                printf("\nKV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_draft_savestate_size/(1024*1024));
+            }
+        }
+        return totalbytes;
     }
     return 0;
 }
@@ -4408,6 +4438,12 @@ bool gpttype_load_state_kv(int slot)
         {
             current_context_tokens = savestates[slot].savestate_context_tokens;
             printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
+            if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
+            {
+                llama_memory_clear(llama_get_memory(draft_ctx),true);
+                auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
+                printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
+            }
         }
         return (res > 0);
     }
@@ -4432,6 +4468,15 @@ bool gpttype_clear_state_kv(bool shrink)
             }
             savestates[slot].savestate_context_tokens.clear();
             savestates[slot].current_savestate_size = 0;
+            if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
+            {
+                savestates[slot].current_draft_savestate_buffer.clear();
+                if(shrink)
+                {
+                    savestates[slot].current_draft_savestate_buffer.shrink_to_fit();
+                }
+                savestates[slot].current_draft_savestate_size = 0;
+            }
         }
     }
     return true;
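
For context on the mechanism: the draft-context handling added in this commit reuses the same llama.cpp state round trip that the main context already goes through. Below is a minimal, self-contained sketch of that pattern. The helper names (save_ctx_state, load_ctx_state) and the bare ctx/buffer parameters are illustrative placeholders rather than code from this repository; only the llama_state_get_size / llama_state_get_data / llama_state_set_data calls are taken from the diff itself, and the include path may differ by build setup.

// Sketch of the serialize/restore round trip applied to both llama_ctx_v4 and draft_ctx.
#include <cstdint>
#include <cstdio>
#include <new>
#include <vector>
#include "llama.h"

static size_t save_ctx_state(llama_context * ctx, std::vector<uint8_t> & buffer)
{
    size_t needed = llama_state_get_size(ctx);   // bytes required to serialize the full context state
    try {
        buffer.resize(needed + 512);             // small padding, mirroring the diff; may throw std::bad_alloc
    } catch (const std::bad_alloc &) {
        fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", needed + 512);
        return 0;
    }
    return llama_state_get_data(ctx, buffer.data(), needed); // bytes actually serialized
}

static bool load_ctx_state(llama_context * ctx, const std::vector<uint8_t> & buffer, size_t saved_size)
{
    if (saved_size == 0 || buffer.empty()) {
        return false; // nothing was saved for this context (e.g. no draft model loaded)
    }
    return llama_state_set_data(ctx, buffer.data(), saved_size) > 0;
}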

otherarch/otherarch.h

Lines changed: 2 additions & 0 deletions
@@ -521,6 +521,8 @@ struct savestate_data
 {
     size_t current_savestate_size = 0;
     std::vector<uint8_t> current_savestate_buffer;
+    size_t current_draft_savestate_size = 0;
+    std::vector<uint8_t> current_draft_savestate_buffer;
     std::vector<gpt_vocab::id> savestate_context_tokens; //for context clones
 };
 
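
For quick reference, the savestate_data struct as it stands after this commit, restated from the diff context above; the trailing comments are editorial annotations, not part of the source.

struct savestate_data
{
    size_t current_savestate_size = 0;                    // serialized size of the main context state
    std::vector<uint8_t> current_savestate_buffer;        // main context state bytes
    size_t current_draft_savestate_size = 0;              // serialized size of the draft context state; stays 0 when no draft model is loaded
    std::vector<uint8_t> current_draft_savestate_buffer;  // draft context state bytes
    std::vector<gpt_vocab::id> savestate_context_tokens;  //for context clones
};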
