@@ -101,8 +101,12 @@ def max_seq_length(self, value: int) -> None:
                 device=self.cos.device,
             )
 
-    def _are_kv_caches_assigned(self) -> bool:
-        return any(block.attn.kv_cache is not None for block in self.transformer.h)
+    def are_kv_caches_assigned(self) -> bool:
+        status = [block.attn.kv_cache is not None for block in self.transformer.h]
+        result = any(status)
+        if result and not all(status):
+            raise IndexError("Some layers have KV caches assigned, but not all")
+        return result
 
     def assign_kv_caches(
         self, kv_caches: List[KVCache]
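
Note: the renamed public helper enforces an all-or-nothing invariant — either every transformer block has a KV cache assigned or none does. A standalone sketch of the same `any`/`all` pattern, with plain booleans standing in for the per-layer `block.attn.kv_cache is not None` checks (the `check_assignment` name is hypothetical, for illustration only):

```python
def check_assignment(status: list) -> bool:
    # True only if every layer has a cache; raises on partial assignment.
    result = any(status)
    if result and not all(status):
        raise IndexError("Some layers have KV caches assigned, but not all")
    return result

assert check_assignment([False, False, False]) is False  # nothing assigned
assert check_assignment([True, True, True]) is True      # fully assigned
# check_assignment([True, False, True]) raises IndexError
```
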
@@ -120,7 +124,7 @@ def assign_kv_caches(
             kv_caches: KV caches, one for each layer of the model
 
         """
-        if self._are_kv_caches_assigned():
+        if self.are_kv_caches_assigned():
             raise ValueError("Model has KV caches assigned already")
         if len(kv_caches) != self.config.n_layer:
             raise ValueError(f"kv_caches must have one entry per layer, so {self.config.n_layer} entries")
@@ -154,7 +158,7 @@ def set_kv_cache(
             `self.max_seq_length`
 
         """
-        if self._are_kv_caches_assigned() and not self._default_kv_cache:
+        if self.are_kv_caches_assigned() and not self._default_kv_cache:
             raise ValueError("Model has KV caches assigned already")
         if max_seq_length is None:
             max_seq_length = self.max_seq_length
@@ -269,15 +273,14 @@ def forward(
             raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
         for_prefill = False
         if input_pos is not None:
-            for_prefill = (input_pos == 0)
             # Few tokens generation. This needs a KV cache. If none is assigned,
             # the call fails
-            msg_suffix = f"."
-            for l_ix, block in enumerate(self.transformer.h):
-                kv_cache = block.attn.kv_cache
-                if kv_cache is None:
-                    raise ValueError("KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_cache'")
-                if not for_prefill:
+            if not self.are_kv_caches_assigned():
+                raise ValueError("KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_cache'")
+            for_prefill = (input_pos == 0)
+            if not for_prefill:
+                for l_ix, block in enumerate(self.transformer.h):
+                    kv_cache = block.attn.kv_cache
                     if kv_cache.next_token_pos is None:
                         raise ValueError("Inference calls need to start with pre-fill, i.e. 'input_pos=0'")
                     if kv_cache.next_token_pos != input_pos:
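
Note: the reordered checks make the inference protocol explicit — generation must start with a prefill call at `input_pos=0`, and every subsequent single-token call must pass the position the caches expect (`kv_cache.next_token_pos`). A hedged sketch of that calling convention; `model` and `max_new_tokens` are assumed names, and the forward signature is inferred from this diff rather than confirmed API:

```python
import torch

prompt = torch.tensor([[1, 2, 3, 4]])          # (batch, T) token ids
logits = model(prompt, input_pos=0)            # prefill: populates all KV caches
pos = prompt.shape[1]
for _ in range(max_new_tokens):
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    logits = model(next_token, input_pos=pos)  # must equal kv_cache.next_token_pos
    pos += 1
```
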
@@ -373,6 +376,7 @@ def clear_kv_cache(self) -> None:
         if self._default_kv_cache:
             for block in self.transformer.h:
                 block.attn.kv_cache = None
+            self._default_kv_cache = False
 
     def get_kv_cache_params(self) -> Optional[KVCacheParams]:
         kv_cache = self.transformer.h[0].attn.kv_cache
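
Note: resetting `_default_kv_cache` on clear matters for the `set_kv_cache` guard above — if the flag stayed `True` after clearing, externally assigned caches attached later could be silently overwritten by a subsequent `set_kv_cache` call instead of raising. A hypothetical lifecycle using the method names from this diff (`my_caches` is a placeholder list of per-layer `KVCache` objects):

```python
model.set_kv_cache(batch_size=1)    # default caches; _default_kv_cache = True
model.clear_kv_cache()              # caches dropped and flag reset to False
model.assign_kv_caches(my_caches)   # external caches attached
model.set_kv_cache(batch_size=1)    # now correctly raises ValueError
```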