@@ -17,6 +17,9 @@ struct llama_ubatch;
1717struct  llama_kv_cache  : public  llama_memory_i  {
1818    using  llama_memory_i::llama_memory_i;
1919
20+     virtual  void  restore () = 0; //  call if batch processing fails to restore the cache state
21+     virtual  void  commit () = 0;  //  call after successful batch processing
22+ 
2023    virtual  int32_t   get_n_tokens ()   const  = 0;
2124    virtual  uint32_t  get_used_cells () const  = 0; //  TODO: remove, this is too-specific to the unified cache
2225
@@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i {
2528    bool  get_can_edit () const  override  { return  get_can_shift (); }
2629};
2730
31+ struct  llama_kv_cache_guard  {
32+     llama_kv_cache_guard (llama_kv_cache * kv) : kv(kv) {}
33+ 
34+     ~llama_kv_cache_guard () {
35+         kv->restore ();
36+     }
37+ 
38+     void  commit () {
39+         kv->commit ();
40+     }
41+ 
42+ private: 
43+     llama_kv_cache * kv;
44+ };
45+ 
2846struct  llama_kv_cell  {
2947    llama_pos pos   = -1 ;
30-     llama_pos delta = 0 ;
48+     llama_pos delta =   0 ;
3149    int32_t    src   = -1 ; //  used by recurrent state models to copy states
3250    int32_t    tail  = -1 ;
3351
@@ -46,17 +64,6 @@ struct llama_kv_cell {
4664    }
4765};
4866
49- //  a structure holds information about the slot found in llama_kv_cache_find_slot
50- struct  llama_kv_cache_slot_info  {
51-     std::pair<uint32_t , uint32_t > boundaries; //  slot boundaries [begin, end)
52-     bool  found = false ;                       //  the slot was found
53- 
54-     explicit  llama_kv_cache_slot_info (bool  found_) : found{found_} {}
55-     llama_kv_cache_slot_info (uint32_t  begin, uint32_t  end) : boundaries{begin, end}, found{true } {}
56- 
57-     operator  bool () const  { return  found; }
58- };
59- 
6067//  ring-buffer of cached KV data
6168//  TODO: pimpl
6269//  TODO: add notion of max sequences
@@ -93,6 +100,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
93100    void  clear () override ;
94101    void  defrag () override ;
95102
103+     virtual  void  restore () override ;
104+     virtual  void  commit () override ;
105+ 
96106    bool  seq_rm   (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override ;
97107    void  seq_cp   (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override ;
98108    void  seq_keep (llama_seq_id seq_id) override ;
@@ -105,10 +115,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
105115
106116    //  find an empty slot of size "n_tokens" in the cache
107117    //  updates the cache head
108-     //  returns a structure holding information about the slot found
109118    //  Note: On success, it's important that cache.head points
110119    //  to the first cell of the slot.
111-     llama_kv_cache_slot_info  find_slot (const  llama_ubatch & batch);
120+     bool  find_slot (const  llama_ubatch & batch);
112121
113122    //  TODO: maybe not needed
114123    uint32_t  get_padding (const  llama_cparams & cparams) const ;
@@ -128,7 +137,18 @@ class llama_kv_cache_unified : public llama_kv_cache {
128137    //  return true if cells have been moved
129138    bool  defrag_prepare (int32_t  n_max_nodes);
130139
131-     //  state save/load
140+     //  commit/restore cache
141+ 
142+     struct  slot_range  {
143+         uint32_t  p0 = 0 ;
144+         uint32_t  p1 = 0 ;
145+     };
146+ 
147+     struct  {
148+         std::vector<slot_range> ranges;
149+     } pending;
150+ 
151+     //  state write/load
132152
133153    void  state_write (llama_io_write_i & io, llama_seq_id seq_id = -1 ) const ;
134154    void  state_read  (llama_io_read_i  & io, llama_seq_id seq_id = -1 );
@@ -183,59 +203,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
183203//     using llama_kv_cache_unified::llama_kv_cache_unified;
184204// };
185205
186- // 
187- //  kv cache restore
188- // 
189- 
190- //  saves the kv_cache state for future recovery.
191- //  used to rollback llama_kv_cache_find_slot changes.
192- struct  llama_kv_slot_restorer  {
193-     struct  llama_kv_cache_state  {
194-         uint32_t  head = 0 ;
195-         uint32_t  n    = 0 ;
196-     } old_state;
197- 
198-     //  for non-recurrent models only
199-     //  list of slots to restore
200-     std::vector<std::pair<uint32_t , uint32_t >> slot_boundaries;
201- 
202-     bool  do_restore = false ;
203- 
204-     llama_kv_cache_unified & cache;
205- 
206-     explicit  llama_kv_slot_restorer (llama_kv_cache_unified & cache) : cache(cache) {
207-         old_state.head  = cache.head ;
208-         old_state.n     = cache.n ;
209-     }
210- 
211-     //  saves a slot information for future restoration
212-     void  save (const  llama_kv_cache_slot_info & slot) {
213-         if  (slot) {
214-             do_restore = true ;
215-             if  (slot.boundaries .first  != slot.boundaries .second ) {
216-                 slot_boundaries.push_back (slot.boundaries );
217-             }
218-         }
219-     }
220- 
221-     //  must be explicitly called to restore the kv_cache state
222-     //  and rollback changes from all llama_kv_cache_find_slot calls
223-     void  restore () {
224-         if  (do_restore) {
225-             cache.head  = old_state.head ;
226-             cache.n     = old_state.n ;
227- 
228-             if  (cache.recurrent ) { //  recurrent models like Mamba or RWKV can't have a state partially erased
229-                 cache.seq_rm (-1 , -1 , -1 );
230-             } else  {
231-                 for  (auto  & slot : slot_boundaries) {
232-                     cache.seq_rm (-1 , slot.first , slot.second );
233-                 }
234-             }
235-         }
236-     }
237- };
238- 
239206//  TODO: maybe become part of the public llama_kv_cache in the future
240207int32_t  llama_kv_cache_n_tokens (const  llama_kv_cache * kv);
241208
0 commit comments