
Commit 8a18e09

committed
added smartcaching implementation inspired by Pento95 (+2 squashed commits)
Squashed commit: [fcc4986] wip basic smart caching test [b6e8b25] wip basic smart caching test
1 parent 1aab32f commit 8a18e09

File tree

5 files changed: +154 -29 lines changed

expose.h

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ struct load_model_inputs
     const bool check_slowness = false;
     const bool highpriority = false;
     const bool swa_support = false;
+    const bool smartcache = false;
     const float lora_multiplier = 1.0f;
     const bool quiet = false;
     const int debugmode = 0;

gpttype_adapter.cpp

Lines changed: 133 additions & 23 deletions
@@ -21,6 +21,7 @@
 #include <string>
 #include <cctype>
 #include <locale>
+#include <chrono>

 #include "utils.h"

@@ -151,7 +152,7 @@ static int delayed_generated_tokens_limit = 0;
 std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling
 static std::map<int,std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index

-const int savestate_limit = 3;
+const int savestate_limit = 4;
 static savestate_data savestates[savestate_limit];

 inline int kcpp_cpu_has_blas(void) {
@@ -1826,9 +1827,29 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
     return true;
 }

+//counts the number of matching prefix tokens between two sequences, returns percentage matched 0.0 to 1.0
+float ComputePrefixMatchPercent(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
+{
+    int match_count = 0;
+    size_t min_length = std::min(current_context_tokens.size(), new_context_tokens.size());
+    for (size_t i = 0; i < min_length; ++i) {
+        if (current_context_tokens[i] == new_context_tokens[i]) {
+            match_count++;
+        } else {
+            break;
+        }
+    }
+    // Handle case where both sequences are empty to avoid division by zero
+    if (min_length == 0) {
+        return 0.0f; // Both empty sequences are considered 100% matched
+    }
+    return static_cast<float>(match_count) / static_cast<float>(min_length);
+}
+
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
+//returns true if contextshift is doable, executes it if dryrun is false
+bool DoContextShifting(llama_context * ctx, llama_context * draft_ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx, bool dryrun)
 {
     //scan from start old and new ctx, until first mismatch found, save as p0
     //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
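A note on the metric the new helper computes: the matching prefix length is divided by the length of the shorter of the two sequences, so a short incoming prompt that is entirely a prefix of the cached context scores 1.0, while an early divergence pulls the score toward zero. The following is a minimal standalone sketch (illustration only, not part of the commit) that mirrors that behaviour:

#include <algorithm>
#include <cstdio>
#include <vector>

// Mirrors ComputePrefixMatchPercent from the patch: fraction of the shorter
// sequence that matches as a contiguous prefix.
static float prefix_match_percent(const std::vector<int> &a, const std::vector<int> &b)
{
    size_t min_length = std::min(a.size(), b.size());
    if (min_length == 0) { return 0.0f; } // empty input scores zero, as in the patch
    int match_count = 0;
    for (size_t i = 0; i < min_length; ++i) {
        if (a[i] == b[i]) { match_count++; } else { break; }
    }
    return static_cast<float>(match_count) / static_cast<float>(min_length);
}

int main()
{
    std::vector<int> cached   = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int> incoming = {1, 2, 3, 4, 9, 9};
    // Four tokens match before the first divergence, divided by min(8, 6) = 6 -> 0.67.
    printf("match = %.2f\n", prefix_match_percent(cached, incoming));
    return 0;
}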
@@ -1860,11 +1881,9 @@ void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vec
         }
     }

-    //printf("\nPN: %d, NTL: %d, CCT: %d,TS:%d, diff:%d, sft:%d\n",purgeneeded,new_tokens_len,current_context_tokens.size(),trimstart,(new_tokens_len - trimstart),ShortfallThreshold);
-
     if(!purgeneeded || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < ShortfallThreshold)
     {
-        return; //no purge is needed
+        return false; //no purge is needed
     }

     //at least this many tokens need to match, otherwise don't bother trimming
@@ -1881,30 +1900,38 @@ void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vec
         int found = ArrFindIndexOf(current_context_tokens,shared);
         if(found>=0 && found > trimstart)
         {
-
-            //extract the unwanted tokens out from context and KV
-            int diff = found - trimstart;
-            llama_memory_seq_rm(llama_get_memory(ctx), 0, trimstart, trimstart + diff);
-            llama_memory_seq_add(llama_get_memory(ctx), 0, trimstart + diff, -1, -diff);
-            if(draft_ctx)
+            if(!dryrun)
             {
-                llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, trimstart, trimstart + diff);
-                llama_memory_seq_add(llama_get_memory(draft_ctx), 0, trimstart + diff, -1, -diff);
-            }
-
-            for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
-            {
-                current_context_tokens[i - diff] = current_context_tokens[i];
+                //extract the unwanted tokens out from context and KV
+                int diff = found - trimstart;
+                llama_memory_seq_rm(llama_get_memory(ctx), 0, trimstart, trimstart + diff);
+                llama_memory_seq_add(llama_get_memory(ctx), 0, trimstart + diff, -1, -diff);
+                if(draft_ctx)
+                {
+                    llama_memory_seq_rm(llama_get_memory(draft_ctx), 0, trimstart, trimstart + diff);
+                    llama_memory_seq_add(llama_get_memory(draft_ctx), 0, trimstart + diff, -1, -diff);
+                }
+                for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
+                {
+                    current_context_tokens[i - diff] = current_context_tokens[i];
+                }
+                printf("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart + 1);
+                current_context_tokens.resize(current_context_tokens.size() - diff);
             }
-
-            printf("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart + 1);
-
-            current_context_tokens.resize(current_context_tokens.size() - diff);
+            return true;
         }
     }
+    return false;
+
+}

+//returns true if context shifting is possible. does not execute the shift
+bool CanContextShift(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
+{
+    return DoContextShifting(nullptr,nullptr,current_context_tokens,new_context_tokens,genamt,nctx,true);
 }

+
 static int GetBatchSize(int desiredBlasBatchSize,FileFormat in_file_format)
 {
     //check if approved to use BLAS
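Worth calling out about the refactor above: CanContextShift is just the dry-run probe, and it passes nullptr for both contexts, which is only safe because every KV-cache mutation in DoContextShifting is gated behind if(!dryrun). A hedged sketch of the resulting probe-then-execute call pattern (function names from the patch, surrounding variables assumed):

// Probe: nothing is dereferenced or erased while dryrun is true.
bool shiftable = CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx);
if (shiftable)
{
    // Execute: the same routine with live contexts and dryrun=false actually trims the KV cache.
    DoContextShifting(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx, false);
}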
@@ -1978,6 +2005,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     kcpp_data->use_smartcontext = inputs.use_smartcontext;
     kcpp_data->use_contextshift = inputs.use_contextshift;
     kcpp_data->use_fastforward = inputs.use_fastforward;
+    kcpp_data->smartcache = inputs.smartcache;
+    if(!kcpp_data->use_fastforward && kcpp_data->smartcache)
+    {
+        kcpp_data->smartcache = false;
+        printf("\nSmartCache IS DISABLED!\nSmartCache requires Fast Forwarding!\n");
+    }
     kcpp_data->swa_full = !inputs.swa_support;
     if (!kcpp_data->swa_full) {
         if (inputs.use_contextshift) {
@@ -3776,6 +3809,61 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }
     bool blank_prompt = (addedmemory=="" && kcpp_data->prompt=="");

+    //smart cache logic
+    if(kcpp_data->smartcache)
+    {
+        const float similarity_threshold = 0.7f;
+        //If CanBeShifted is true, do nothing. Allow shift as normal.
+        if(!CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx))
+        {
+            // If CanBeShifted is false, calculate prefix similarity with current_context_tokens of current context
+            // If similarity > similarity_threshold, do nothing. Allow fast forward as normal.
+            float similarity = ComputePrefixMatchPercent(current_context_tokens,embd_inp);
+            if(similarity < similarity_threshold)
+            {
+                // Otherwise, for each of the currently used kv state slots, calculate ComputePrefixMatch and CanBeShifted
+                // If similarity to any of them > similarity_threshold or CanBeShifted, save current slot and switch to that slot.
+                // Whenever loading or saving current slot, simply tag the slot with a timestamp. When running out of slots after all 3 are used, delete the oldest timestamped slot.
+                // Slot loading and saving completely reuses gpttype_load_state_kv and gpttype_save_state_kv, nothing else is needed.
+                bool foundswap = false;
+                for(int i=0;i<savestate_limit;++i)
+                {
+                    float similaritybeat = ComputePrefixMatchPercent(savestates[i].savestate_context_tokens,embd_inp);
+                    if(similaritybeat > similarity_threshold || CanContextShift(savestates[i].savestate_context_tokens, embd_inp, inputs.max_length, nctx))
+                    {
+                        //found a match. save to the oldest slot thats not the one we are loading
+                        int oldest_slot = get_oldest_slot(i);
+                        if(oldest_slot!=i)
+                        {
+                            if(current_context_tokens.size()>32) //do not save tiny contexts
+                            {
+                                printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]",similaritybeat,i,oldest_slot);
+                                gpttype_save_state_kv(oldest_slot);
+                            }
+                            else
+                            {
+                                printf("\n[SmartCache Match of %.2f in slot %d. Switching...]",similaritybeat,i);
+
+                            }
+                            gpttype_load_state_kv(i);
+                            foundswap = true;
+                            break;
+                        }
+                    }
+                }
+                if(!foundswap) //could not match anything, just save kv and continue
+                {
+                    if(current_context_tokens.size()>32) //do not save tiny contexts
+                    {
+                        int oldest_slot = get_oldest_slot(-1);
+                        printf("\n[SmartCache No Match, Saving into slot %d...]",oldest_slot);
+                        gpttype_save_state_kv(oldest_slot);
+                    }
+                }
+            }
+        }
+    }
+
     if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
     {
         if(!blank_prompt)
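To see the decision order of the block above in isolation: a context shift is preferred, then a sufficiently similar live context is fast-forwarded as usual, and only then are the saved slots scanned for a better match or, failing that, the live KV is parked in the oldest slot. Below is a self-contained sketch of that slot-selection order (illustration only, not from the commit; the real code also treats a slot as a match when CanContextShift succeeds, which is reduced here to the same prefix test):

#include <algorithm>
#include <cstdio>
#include <vector>

static float prefix_match(const std::vector<int> &a, const std::vector<int> &b)
{
    size_t n = std::min(a.size(), b.size());
    size_t m = 0;
    while (m < n && a[m] == b[m]) { ++m; }
    return n == 0 ? 0.0f : (float)m / (float)n;
}

int main()
{
    const float threshold = 0.7f;                   // similarity_threshold in the patch
    std::vector<int> live   = {1, 2, 3, 4, 5, 6};   // tokens currently backing the KV cache
    std::vector<int> prompt = {9, 9, 9, 4, 5, 6};   // incoming request, shares no prefix with it
    std::vector<std::vector<int>> slots = {         // saved snapshots, token lists only
        {}, {9, 9, 9, 4}, {}, {}
    };

    if (prefix_match(live, prompt) >= threshold) {
        printf("keep the live context and fast forward as usual\n");
        return 0;
    }
    for (size_t i = 0; i < slots.size(); ++i) {
        if (prefix_match(slots[i], prompt) > threshold) {
            // Slot 1 scores 4/4 = 1.0 here: the live KV would be parked in the
            // oldest other slot and slot 1 restored before generation continues.
            printf("save live KV, restore slot %zu\n", i);
            return 0;
        }
    }
    printf("no match: save live KV to the oldest slot and reprocess the prompt\n");
    return 0;
}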
@@ -3825,7 +3913,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         {
             if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
             {
-                PurgeMissingTokens(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx);
+                DoContextShifting(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx, false);
                 triggersc = false;
             }
             if(kcpp_data->use_fastforward)
@@ -4709,6 +4797,9 @@ size_t gpttype_save_state_kv(int slot)
         totalbytes += res;
         savestates[slot].current_savestate_size = newsize;
         savestates[slot].savestate_context_tokens = current_context_tokens;
+        auto timenow = std::chrono::system_clock::now();
+        auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
+        savestates[slot].last_used = timestamp;
         printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
     }
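For clarity, the slot timestamp recorded above is plain Unix-epoch seconds; only its relative ordering is ever compared when picking a victim slot. A tiny standalone illustration (not part of the commit):

#include <chrono>
#include <cstdio>

int main()
{
    auto now = std::chrono::system_clock::now();
    // Same expression as in the patch: whole seconds since the Unix epoch.
    long long secs = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()).count();
    printf("%lld\n", secs); // e.g. 1735689600 for 2025-01-01 00:00:00 UTC
    return 0;
}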

@@ -4758,6 +4849,9 @@ bool gpttype_load_state_kv(int slot)
             auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
             printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
         }
+        auto timenow = std::chrono::system_clock::now();
+        auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
+        savestates[slot].last_used = timestamp;
     }
     return (res > 0);
 }
@@ -4791,9 +4885,25 @@ bool gpttype_clear_state_kv(bool shrink)
                     }
                     savestates[slot].current_draft_savestate_size = 0;
                 }
+                savestates[slot].last_used = 0;
             }
         }
         return true;
     }
     return false;
 }
+
+int get_oldest_slot(int excludeSlotId)
+{
+    int64_t slotage = INT64_MAX; // Initialize with maximum possible value
+    int slotid = 0;
+    for(int i=0;i<savestate_limit;++i)
+    {
+        if(savestates[i].last_used <= slotage && i!=excludeSlotId)
+        {
+            slotage = savestates[i].last_used;
+            slotid = i;
+        }
+    }
+    return slotid;
+}
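The eviction rule above behaves like least-recently-used with one wrinkle: because the comparison is <= and the loop walks slots in ascending order, ties (including never-used slots whose last_used is still 0) resolve to the highest-numbered slot that is not excluded. A standalone check of that behaviour with made-up timestamps (illustration only):

#include <cstdint>
#include <cstdio>

static const int kSlots = 4;                        // savestate_limit in the patch
static int64_t last_used[kSlots] = {100, 0, 0, 50}; // invented ages for the demo

// Same scan as get_oldest_slot above, with the slot array reduced to timestamps.
static int oldest_slot(int exclude)
{
    int64_t best = INT64_MAX;
    int id = 0;
    for (int i = 0; i < kSlots; ++i) {
        if (last_used[i] <= best && i != exclude) { best = last_used[i]; id = i; }
    }
    return id;
}

int main()
{
    printf("%d\n", oldest_slot(-1)); // prints 2: slots 1 and 2 tie at 0, the later index wins
    printf("%d\n", oldest_slot(2));  // prints 1: slot 2 is excluded, so slot 1 is chosen
    return 0;
}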

koboldcpp.py

Lines changed: 16 additions & 5 deletions
@@ -53,7 +53,7 @@
 default_ttsmaxlen = 4096
 default_visionmaxres = 1024
 net_save_slots = 12
-savestate_limit = 3 #3 savestate slots
+savestate_limit = 4 #savestate slots
 default_vae_tile_threshold = 768
 default_native_ctx = 16384
 overridekv_max = 4
@@ -217,6 +217,7 @@ class load_model_inputs(ctypes.Structure):
                 ("check_slowness", ctypes.c_bool),
                 ("highpriority", ctypes.c_bool),
                 ("swa_support", ctypes.c_bool),
+                ("smartcache", ctypes.c_bool),
                 ("lora_multiplier", ctypes.c_float),
                 ("quiet", ctypes.c_bool),
                 ("debugmode", ctypes.c_int)]
@@ -1519,6 +1520,7 @@ def load_model(model_filename):
     inputs.check_slowness = (not args.highpriority and os.name == 'nt' and 'Intel' in platform.processor())
     inputs.highpriority = args.highpriority
     inputs.swa_support = args.useswa
+    inputs.smartcache = args.smartcache
     inputs = set_backend_props(inputs)
     ret = handle.load_model(inputs)
     return ret
@@ -4344,7 +4346,7 @@ def do_POST(self):
         if self.path.endswith('/api/admin/check_state'):
             if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword):
                 cur_states = []
-                for sl in range(savestate_limit): #0,1,2
+                for sl in range(savestate_limit): #0,1,2,3
                     oldstate = handle.calc_old_state_kv(sl)
                     oldtokencnt = handle.calc_old_state_tokencount(sl)
                     cur_states.append({"tokens":oldtokencnt,"size":oldstate})
@@ -4997,8 +4999,8 @@ def get_problematic_scaler():
     import customtkinter as ctk
     nextstate = 0 #0=exit, 1=launch
     corrupt_scaler = get_problematic_scaler()
-    original_windowwidth = int(860 if corrupt_scaler else 580)
-    original_windowheight = int(740 if corrupt_scaler else 580)
+    original_windowwidth = int(860 if corrupt_scaler else 584)
+    original_windowheight = int(740 if corrupt_scaler else 584)
     windowwidth = original_windowwidth
     windowheight = original_windowheight
     ctk.set_appearance_mode("dark")
@@ -5160,6 +5162,7 @@ def hide_tooltip(event):
     contextshift_var = ctk.IntVar(value=1)
     fastforward_var = ctk.IntVar(value=1)
     swa_var = ctk.IntVar(value=0)
+    smartcache_var = ctk.IntVar(value=0)
     remotetunnel_var = ctk.IntVar(value=0)
     smartcontext_var = ctk.IntVar()
     flashattention_var = ctk.IntVar(value=0)
@@ -5626,6 +5629,10 @@ def toggleswa(a,b,c):
         if swa_var.get()==1:
             contextshift_var.set(0)

+    def togglesmartcache(a,b,c):
+        if smartcache_var.get()==1:
+            fastforward_var.set(1)
+
     def togglefastforward(a,b,c):
         if fastforward_var.get()==0:
             contextshift_var.set(0)
@@ -5845,6 +5852,7 @@ def changerunmode(a,b,c):
     makecheckbox(tokens_tab, "Use ContextShift", contextshift_var, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
     makecheckbox(tokens_tab, "Use FastForwarding", fastforward_var, 3,tooltiptxt="Use fast forwarding to recycle previous context (always reprocess if disabled).\nRecommended.", command=togglefastforward)
     makecheckbox(tokens_tab, "Use Sliding Window Attention (SWA)", swa_var, 4,tooltiptxt="Allows Sliding Window Attention (SWA) KV Cache, which saves memory but cannot be used with context shifting.", command=toggleswa)
+    makecheckbox(tokens_tab, "Use SmartCache", smartcache_var, 5,tooltiptxt="Enables intelligent context switching by saving KV cache snapshots to RAM. Requires fast forwarding.", command=togglesmartcache)

     # context size
     makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 18, width=280, set=7,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")
@@ -5854,7 +5862,7 @@ def changerunmode(a,b,c):

     nativectx_entry, nativectx_label = makelabelentry(tokens_tab, "Override Native Context:", customrope_nativectx, row=23, padx=(246 if corrupt_scaler else 146), singleline=True, tooltip="Overrides the native trained context of the loaded model with a custom value to be used for Rope scaling.")
     customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale, row=23, padx=(160 if corrupt_scaler else 100), singleline=True, tooltip="For Linear RoPE scaling. RoPE frequency scale.")
-    customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base, row=24, padx=(160 if corrupt_scaler else 100), singleline=True, tooltip="For NTK Aware Scaling. RoPE frequency base.")
+    customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "Base:", customrope_base, row=23, padx=(420 if corrupt_scaler else 220), singleline=True, tooltip="For NTK Aware Scaling. RoPE frequency base.",labelpadx=(280 if corrupt_scaler else 180))
     def togglerope(a,b,c):
         if customrope_var.get() == 1:
             manualropebox.grid()
@@ -6143,6 +6151,7 @@ def export_vars():
         args.noshift = contextshift_var.get()==0
         args.nofastforward = fastforward_var.get()==0
         args.useswa = swa_var.get()==1
+        args.smartcache = smartcache_var.get()==1
         args.remotetunnel = remotetunnel_var.get()==1
         args.foreground = keepforeground.get()==1
         args.cli = terminalonly.get()==1
@@ -6364,6 +6373,7 @@ def import_vars(dict):
         contextshift_var.set(0 if "noshift" in dict and dict["noshift"] else 1)
         fastforward_var.set(0 if "nofastforward" in dict and dict["nofastforward"] else 1)
         swa_var.set(1 if "useswa" in dict and dict["useswa"] else 0)
+        smartcache_var.set(1 if "smartcache" in dict and dict["smartcache"] else 0)
         remotetunnel_var.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
         keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
         terminalonly.set(1 if "cli" in dict and dict["cli"] else 0)
@@ -8335,6 +8345,7 @@ def range_checker(arg: str):
     advparser.add_argument("--noshift","--no-context-shift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
     advparser.add_argument("--nofastforward", help="If set, do not attempt to fast forward GGUF context (always reprocess). Will also enable noshift", action='store_true')
     advparser.add_argument("--useswa", help="If set, allows Sliding Window Attention (SWA) KV Cache, which saves memory but cannot be used with context shifting.", action='store_true')
+    advparser.add_argument("--smartcache", help="Enables intelligent context switching by saving KV cache snapshots to RAM. Requires fast forwarding.", action='store_true')
     advparser.add_argument("--ropeconfig", help="If set, uses customized RoPE scaling from configured frequency scale and frequency base (e.g. --ropeconfig 0.25 10000). Otherwise, uses NTK-Aware scaling set automatically based on context size. For linear rope, simply set the freq-scale and ignore the freq-base",metavar=('[rope-freq-scale]', '[rope-freq-base]'), default=[0.0, 10000.0], type=float, nargs='+')
     advparser.add_argument("--overridenativecontext", help="Overrides the native trained context of the loaded model with a custom value to be used for Rope scaling.",metavar=('[trained context]'), type=int, default=0)
     compatgroup3 = advparser.add_mutually_exclusive_group()
