Skip to content

Commit e9ae0cb

Browse files
committed
added support for RNN models in smartcache
1 parent cde4791 commit e9ae0cb

File tree

2 files changed

+158
-20
lines changed

2 files changed

+158
-20
lines changed

gpttype_adapter.cpp

Lines changed: 155 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1887,11 +1887,25 @@ float ComputePrefixMatchPercent(std::vector<int> &current_context_tokens, std::v
18871887
}
18881888
// Handle case where both sequences are empty to avoid division by zero
18891889
if (min_length == 0) {
1890-
return 0.0f; // Both empty sequences are considered 100% matched
1890+
return 0.0f; // Both empty sequences are considered not matched
18911891
}
18921892
return static_cast<float>(match_count) / static_cast<float>(min_length);
18931893
}
18941894

1895+
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
1896+
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
1897+
{
1898+
if (sequence1.size() > sequence2.size() || sequence1.size()==0 || sequence2.size()==0) {
1899+
return false;
1900+
}
1901+
for (size_t i = 0; i < sequence1.size(); ++i) {
1902+
if (sequence1[i] != sequence2[i]) {
1903+
return false;
1904+
}
1905+
}
1906+
return true;
1907+
}
1908+
18951909
//given an old GGUF context and a new context that has some middle portion removed,
18961910
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
18971911
//returns true if contextshift is doable, executes it if dryrun is false
@@ -3921,20 +3935,83 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
39213935
{
39223936
shiftable = false;
39233937
}
3924-
const float similarity_threshold = 0.7f;
3925-
//If CanBeShifted is true, do nothing. Allow shift as normal.
3926-
if(!(shiftable && CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx)))
3938+
3939+
//we handle recurrent models differently since they require a full subset match
3940+
if(is_recurrent)
3941+
{
3942+
bool curr_usable = FullyContainedPrefix(current_context_tokens,embd_inp);
3943+
if(!curr_usable)
3944+
{
3945+
//see if we have any other usable contexts out there
3946+
int bestslot = -1;
3947+
int bestlen = 0;
3948+
int identical_slot = get_identical_existing_slot(); //see if the slot already exists
3949+
for(int i=0;i<savestate_limit;++i)
3950+
{
3951+
bool target_usable = FullyContainedPrefix(savestates[i].savestate_context_tokens,embd_inp);
3952+
int target_len = savestates[i].savestate_context_tokens.size();
3953+
if(target_usable && target_len>bestlen)
3954+
{
3955+
bestlen = target_len;
3956+
bestslot = i;
3957+
}
3958+
}
3959+
if(bestslot!=-1) //found a good slot to load
3960+
{
3961+
int oldest_slot = get_oldest_slot(bestslot);
3962+
if(oldest_slot!=bestslot)
3963+
{
3964+
if(current_context_tokens.size() > 32) //do not save tiny contexts
3965+
{
3966+
if(identical_slot==-1)
3967+
{
3968+
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Saving into slot %d and switching...]\n",bestlen,bestslot,oldest_slot);
3969+
gpttype_save_state_kv(oldest_slot);
3970+
} else {
3971+
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Already saved in slot %d, switching...]\n",bestlen,bestslot,identical_slot);
3972+
touch_slot(identical_slot);
3973+
}
3974+
}
3975+
else
3976+
{
3977+
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Switching...]\n",bestlen,bestslot);
3978+
}
3979+
gpttype_load_state_kv(bestslot);
3980+
}
3981+
}
3982+
else
3983+
{
3984+
if(current_context_tokens.size() > 32) //do not save tiny contexts
3985+
{
3986+
if(identical_slot==-1)
3987+
{
3988+
int oldest_slot = get_oldest_slot(-1);
3989+
printf("\n[SmartCache RNN No Match, Saving into slot %d...]\n",oldest_slot);
3990+
gpttype_save_state_kv(oldest_slot);
3991+
}
3992+
else
3993+
{
3994+
printf("\n[SmartCache RNN No Match, Already saved in slot %d]\n",identical_slot);
3995+
touch_slot(identical_slot);
3996+
}
3997+
}
3998+
}
3999+
}
4000+
}
4001+
else if(!(shiftable && CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx))) //If CanBeShifted is true, do nothing. Allow shift as normal.
39274002
{
39284003
// If CanBeShifted is false, calculate prefix similarity with current_context_tokens of current context
39294004
// If similarity > similarity_threshold, do nothing. Allow fast forward as normal.
39304005
float similarity = ComputePrefixMatchPercent(current_context_tokens,embd_inp);
4006+
const float similarity_threshold = 0.7f;
39314007
if(similarity < similarity_threshold)
39324008
{
39334009
// Otherwise, for each of the currently used kv state slots, calculate ComputePrefixMatch and CanBeShifted
39344010
// If similarity to any of them > similarity_threshold or CanBeShifted, save current slot and switch to that slot.
39354011
// Whenever loading or saving current slot, simply tag the slot with a timestamp. When running out of slots after all 3 are used, delete the oldest timestamped slot.
39364012
// Slot loading and saving completely reuses gpttype_load_state_kv and gpttype_save_state_kv, nothing else is needed.
39374013
bool foundswap = false;
4014+
int identical_slot = get_identical_existing_slot(); //see if a slot already exists with identical data to current
39384015
for(int i=0;i<savestate_limit;++i)
39394016
{
39404017
float similaritybeat = ComputePrefixMatchPercent(savestates[i].savestate_context_tokens,embd_inp);
@@ -3944,15 +4021,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
39444021
int oldest_slot = get_oldest_slot(i);
39454022
if(oldest_slot!=i)
39464023
{
3947-
if(current_context_tokens.size()>32) //do not save tiny contexts
4024+
if(current_context_tokens.size() > 32) //do not save tiny contexts
39484025
{
3949-
printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]",similaritybeat,i,oldest_slot);
3950-
gpttype_save_state_kv(oldest_slot);
4026+
if(identical_slot==-1)
4027+
{
4028+
printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]\n",similaritybeat,i,oldest_slot);
4029+
gpttype_save_state_kv(oldest_slot);
4030+
} else {
4031+
printf("\n[SmartCache Match of %.2f in slot %d. Already saved in slot %d, switching...]\n",similaritybeat,i,identical_slot);
4032+
touch_slot(identical_slot);
4033+
}
39514034
}
39524035
else
39534036
{
3954-
printf("\n[SmartCache Match of %.2f in slot %d. Switching...]",similaritybeat,i);
3955-
4037+
printf("\n[SmartCache Match of %.2f in slot %d. Switching...]\n",similaritybeat,i);
39564038
}
39574039
gpttype_load_state_kv(i);
39584040
foundswap = true;
@@ -3962,11 +4044,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
39624044
}
39634045
if(!foundswap) //could not match anything, just save kv and continue
39644046
{
3965-
if(current_context_tokens.size()>32) //do not save tiny contexts
4047+
if(current_context_tokens.size() > 32) //do not save tiny contexts
39664048
{
3967-
int oldest_slot = get_oldest_slot(-1);
3968-
printf("\n[SmartCache No Match, Saving into slot %d...]",oldest_slot);
3969-
gpttype_save_state_kv(oldest_slot);
4049+
if(identical_slot==-1)
4050+
{
4051+
int oldest_slot = get_oldest_slot(-1);
4052+
printf("\n[SmartCache No Match, Saving into slot %d...]\n",oldest_slot);
4053+
gpttype_save_state_kv(oldest_slot);
4054+
}
4055+
else
4056+
{
4057+
printf("\n[SmartCache No Match, Already saved in slot %d]\n",identical_slot);
4058+
touch_slot(identical_slot);
4059+
}
39704060
}
39714061
}
39724062
}
@@ -4790,6 +4880,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
47904880
delayed_generated_tokens.pop_front();
47914881
}
47924882

4883+
//if running rnn model in smartcache mode, save progress after each gen
4884+
if(kcpp_data->smartcache && is_recurrent && file_format==FileFormat::GGUF_GENERIC && current_context_tokens.size() > 32)
4885+
{
4886+
int identical_slot = get_identical_existing_slot();
4887+
if(identical_slot==-1)
4888+
{
4889+
int oldest_slot = get_oldest_slot(-1);
4890+
gpttype_save_state_kv(oldest_slot);
4891+
}
4892+
else
4893+
{
4894+
touch_slot(identical_slot);
4895+
}
4896+
}
4897+
47934898
if(debugmode==1 && !is_quiet && file_format == FileFormat::GGUF_GENERIC)
47944899
{
47954900
printf("\n");
@@ -4907,9 +5012,7 @@ size_t gpttype_save_state_kv(int slot)
49075012
totalbytes += res;
49085013
savestates[slot].current_savestate_size = newsize;
49095014
savestates[slot].savestate_context_tokens = current_context_tokens;
4910-
auto timenow = std::chrono::system_clock::now();
4911-
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
4912-
savestates[slot].last_used = timestamp;
5015+
touch_slot(slot);
49135016
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
49145017
}
49155018

@@ -4959,9 +5062,7 @@ bool gpttype_load_state_kv(int slot)
49595062
auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
49605063
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
49615064
}
4962-
auto timenow = std::chrono::system_clock::now();
4963-
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
4964-
savestates[slot].last_used = timestamp;
5065+
touch_slot(slot);
49655066
}
49665067
return (res > 0);
49675068
}
@@ -5002,6 +5103,41 @@ bool gpttype_clear_state_kv(bool shrink)
50025103
}
50035104
return false;
50045105
}
5106+
void touch_slot(int slot) //update the slot's last used time and nothing else
5107+
{
5108+
auto timenow = std::chrono::system_clock::now();
5109+
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
5110+
savestates[slot].last_used = timestamp;
5111+
}
5112+
int get_identical_existing_slot() //returns slot number of slot containing exactly the same data, or -1 if nothing
5113+
{
5114+
int64_t slotage = INT64_MAX; // Initialize with maximum possible value
5115+
int slotid = -1;
5116+
int currctxsize = current_context_tokens.size();
5117+
for(int i=0;i<savestate_limit;++i)
5118+
{
5119+
if(savestates[i].savestate_context_tokens.size() == currctxsize)
5120+
{
5121+
bool is_identical = true;
5122+
const auto& slot_tokens = savestates[i].savestate_context_tokens;
5123+
for (size_t j = 0; j < currctxsize; ++j)
5124+
{
5125+
if (slot_tokens[j] != current_context_tokens[j])
5126+
{
5127+
is_identical = false;
5128+
break;
5129+
}
5130+
}
5131+
5132+
if (is_identical)
5133+
{
5134+
slotid = i;
5135+
break;
5136+
}
5137+
}
5138+
}
5139+
return slotid;
5140+
}
50055141

50065142
int get_oldest_slot(int excludeSlotId)
50075143
{

model_adapter.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,4 +140,6 @@ size_t gpttype_calc_old_state_tokencount(int slot);
140140
size_t gpttype_save_state_kv(int slot);
141141
bool gpttype_load_state_kv(int slot);
142142
bool gpttype_clear_state_kv(bool shrink);
143-
int get_oldest_slot(int excludeSlotId);
143+
int get_oldest_slot(int excludeSlotId);
144+
void touch_slot(int slot);
145+
int get_identical_existing_slot();

0 commit comments

Comments
 (0)