@@ -111,7 +111,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
111111
112112 bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override ;
113113 void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override ;
114- void seq_keep (llama_seq_id seq_id) override ;
114+ void seq_keep (llama_seq_id seq_id) override ;
115115 void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override ;
116116 void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override ;
117117
@@ -147,6 +147,15 @@ class llama_kv_cache_unified : public llama_kv_cache {
147147
148148 bool get_can_shift () const override ;
149149
150+ // state write/load
151+
152+ void state_write (llama_io_write_i & io, llama_seq_id seq_id = -1 ) const override ;
153+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1 ) override ;
154+
155+ //
156+ // llama_kv_cache_unified specific API
157+ //
158+
150159 uint32_t get_n () const ;
151160
152161 ggml_tensor * get_k (ggml_context * ctx, int32_t il) const ;
@@ -161,11 +170,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
161170 void set_input_k_shift (ggml_tensor * dst) const ;
162171 void set_input_pos_bucket (ggml_tensor * dst, const llama_ubatch * ubatch) const ;
163172
164- // state write/load
165-
166- void state_write (llama_io_write_i & io, llama_seq_id seq_id = -1 ) const override ;
167- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1 ) override ;
168-
169173private:
170174 const llama_model & model;
171175 const llama_hparams & hparams;
0 commit comments