@@ -10,7 +10,7 @@ use super::with_tracing::QMatMul;
 use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
 use candle::quantized::{gguf_file, QTensor};
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
+use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
 use std::io::{Read, Seek};
 use std::sync::Arc;
 
@@ -136,7 +136,7 @@ struct AttentionWeights {
     num_kv_groups: usize,
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
-    kv_cache: KvCache,
+    kv_cache: ConcatKvCache,
     span_attn: tracing::Span,
 }
 
@@ -160,9 +160,7 @@ impl AttentionWeights {
         let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
         let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
 
-        // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
-        // The cache will grow in chunks of 512 tokens when needed.
-        let kv_cache = KvCache::new(2, 512);
+        let kv_cache = ConcatKvCache::new(2);
 
         let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
 
@@ -211,10 +209,6 @@ impl AttentionWeights {
 
         let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
 
-        // Reset KV cache if we're at the first position
-        if offset == 0 {
-            self.kv_cache.reset();
-        }
         let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
 
         // Make tensor contiguous to avoid some strided copies
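For context: the new `ConcatKvCache` grows the cached keys/values by concatenating each new chunk along the given dimension (here dim 2, the sequence dimension) on every `append`, instead of writing into a pre-allocated buffer that grows in 512-token chunks like the old `KvCache::new(2, 512)`. Below is a minimal standalone sketch of that append-by-concatenation pattern, using a hypothetical `SimpleConcatCache` type built only on `Tensor::cat`; it is not the actual `candle_nn::kv_cache::ConcatKvCache` implementation, just an illustration of the idea.

```rust
use candle::{Result, Tensor};

/// Illustrative append-by-concatenation KV cache (hypothetical type,
/// not the real candle_nn::kv_cache::ConcatKvCache).
struct SimpleConcatCache {
    dim: usize,        // dimension to concatenate along (2 = seq dim for [b, h, s, d])
    k: Option<Tensor>, // cached keys, grown on every append
    v: Option<Tensor>, // cached values, grown on every append
}

impl SimpleConcatCache {
    fn new(dim: usize) -> Self {
        Self { dim, k: None, v: None }
    }

    /// Append new keys/values and return the full cached tensors.
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let k = match &self.k {
            None => k.clone(),
            Some(prev) => Tensor::cat(&[prev, k], self.dim)?,
        };
        let v = match &self.v {
            None => v.clone(),
            Some(prev) => Tensor::cat(&[prev, v], self.dim)?,
        };
        self.k = Some(k.clone());
        self.v = Some(v.clone());
        Ok((k, v))
    }

    /// Drop the cached tensors before starting a new sequence.
    fn reset(&mut self) {
        self.k = None;
        self.v = None;
    }
}
```

In this sketch `reset` simply drops the cached tensors; where the real cache is cleared after this change (now that the `offset == 0` reset in the attention forward pass is gone) is outside the hunks shown in the diff.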