diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 5ffd25f2c7f..fb1a05f4cc9 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -549,7 +549,7 @@ def _update_states(self, attn_updates, update_pos, update_len): style=self.style, update_pos=update_pos, update_len=update_len, - ) + ).detach() for cache_id, update in v_cache_updates.items(): self.v_caches[cache_id] = StaticKVCache.apply_update( self.v_caches[cache_id], @@ -558,7 +558,7 @@ def _update_states(self, attn_updates, update_pos, update_len): style=self.style, update_pos=update_pos, update_len=update_len, - ) + ).detach() self.pos += update_len def _get_lookahead_decoding_mask(