@@ -5,7 +5,7 @@
 import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
-from collections import deque
+from collections import deque, OrderedDict
 
 from . import llama_cpp
 from .llama_types import *
@@ -14,37 +14,50 @@
 class LlamaCache:
     """Cache for a llama.cpp model."""
 
-    def __init__(self):
-        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.cache_state: OrderedDict[
+            Tuple[llama_cpp.llama_token, ...], "LlamaState"
+        ] = OrderedDict()
+        self.capacity_bytes = capacity_bytes
 
-    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
-        return [
-            key
-            for _, key in sorted(
-                ((len(key), key) for key in self.cache_state.keys()), reverse=True
-            )
-        ]
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])
 
-    def _find_key(
-        self, key: Tuple[llama_cpp.llama_token, ...]
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[llama_cpp.llama_token, ...],
     ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
-        for k in self._sorted_keys():
-            if key[: len(k)] == k:
-                return k
-        return None
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key
 
     def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
-        _key = self._find_key(tuple(key))
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found: {key}")
-        return self.cache_state[_key]
+            raise KeyError("Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return self._find_key(tuple(key)) is not None
+        return self._find_longest_prefix_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
-        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
-        self.cache_state[tuple(key)] = value
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes:
+            self.cache_state.popitem(last=False)
 
 
 class LlamaState:
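Taken together, the rewritten LlamaCache is a byte-capacity LRU keyed by longest shared token prefix: __getitem__ returns the entry whose key shares the longest prefix with the query and marks it most recently used, and __setitem__ evicts from the least-recently-used end until the summed state sizes fit under capacity_bytes. A minimal sketch of that behavior, using a hypothetical FakeState stand-in (only llama_state_size matters to the cache's bookkeeping; not part of this diff):

    class FakeState:
        def __init__(self, size):
            self.llama_state_size = size

    cache = LlamaCache(capacity_bytes=100)
    cache[(1, 2, 3)] = FakeState(60)
    cache[(1, 2, 3, 4)] = FakeState(60)   # total hits 120 bytes, so the LRU entry is evicted
    assert cache.cache_size <= 100
    assert (1, 2, 3, 4, 5) in cache       # matched by longest shared prefix, not exact key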
@@ -53,7 +66,7 @@ def __init__(
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: llama_cpp.c_size_t,
+        llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
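With llama_state_size stored as a plain int, LlamaCache.cache_size can sum entry sizes directly. A hedged sketch of how a caller might populate the field, assuming the llama_get_state_size and llama_copy_state_data bindings exposed by llama_cpp (the ctx, eval_tokens, and eval_logits names here are stand-ins, not part of this diff):

    state_size = int(llama_cpp.llama_get_state_size(ctx))  # ctypes size_t -> plain int
    llama_state = (llama_cpp.c_uint8 * state_size)()       # raw state buffer
    llama_cpp.llama_copy_state_data(ctx, llama_state)
    state = LlamaState(
        eval_tokens=eval_tokens.copy(),
        eval_logits=eval_logits.copy(),
        llama_state=llama_state,
        llama_state_size=state_size,  # summable by LlamaCache.cache_size
    )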
@@ -526,10 +539,22 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        if self.cache and prompt_tokens in self.cache:
-            if self.verbose:
-                print("Llama._create_completion: cache hit", file=sys.stderr)
-            self.load_state(self.cache[prompt_tokens])
+        if self.cache:
+            try:
+                cache_item = self.cache[prompt_tokens]
+                cache_prefix_len = Llama.longest_token_prefix(
+                    cache_item.eval_tokens, prompt_tokens
+                )
+                eval_prefix_len = Llama.longest_token_prefix(
+                    self.eval_tokens, prompt_tokens
+                )
+                if cache_prefix_len > eval_prefix_len:
+                    self.load_state(cache_item)
+                    if self.verbose:
+                        print("Llama._create_completion: cache hit", file=sys.stderr)
+            except KeyError:
+                if self.verbose:
+                    print("Llama._create_completion: cache miss", file=sys.stderr)
 
         finish_reason = "length"
         multibyte_fix = 0
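The rule above restores a cached state only when it helps: the saved eval_tokens must share a strictly longer prefix with the prompt than the tokens already evaluated in the live context. A worked example of the comparison, using the longest_token_prefix helper added at the end of this diff:

    prompt_tokens = [1, 2, 3, 4, 5]
    cache_prefix_len = Llama.longest_token_prefix([1, 2, 3, 4], prompt_tokens)  # 4
    eval_prefix_len = Llama.longest_token_prefix([1, 2], prompt_tokens)         # 2
    assert cache_prefix_len > eval_prefix_len  # load cached state; only [5] is left to eval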
@@ -1004,3 +1029,15 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
         return [math.log(x / sum_exps) for x in exps]
+
+    @staticmethod
+    def longest_token_prefix(
+        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
+    ):
+        longest_prefix = 0
+        for _a, _b in zip(a, b):
+            if _a == _b:
+                longest_prefix += 1
+            else:
+                break
+        return longest_prefix
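For reference, the helper counts matching leading tokens and stops at the first mismatch or at the end of the shorter sequence:

    assert Llama.longest_token_prefix([1, 2, 3], [1, 2, 4]) == 2
    assert Llama.longest_token_prefix([1, 2], [1, 2, 3, 4]) == 2  # bounded by the shorter side
    assert Llama.longest_token_prefix([5], [1, 2]) == 0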