@@ -55,8 +55,11 @@ def __init__(
5555 self .register_buffer (
5656 "v_cache" , torch .zeros (cache_shape , dtype = dtype ), persistent = False
5757 )
58+ # We use "kv_cache_pos" here instead of "cache_pos" since the latter is too generic, and we have
59+ # an InitMutableBufferPass that needs to single out this buffer to initialize (and not others)
60+ # since it takes up space in the pte file.
5861 self .register_buffer (
59- "cache_pos " , torch .arange (0 , self .max_seq_len ), persistent = False
62+ "kv_cache_pos " , torch .arange (0 , self .max_seq_len ), persistent = False
6063 )
6164 self .batch_size = batch_size
6265
@@ -105,17 +108,17 @@ def update(
105108 f", but found new key tensors with batch size { k_val .shape [0 ]} !"
106109 )
107110
108- assert (self .cache_pos [0 ] + seq_len ) <= self .max_seq_len
111+ assert (self .kv_cache_pos [0 ] + seq_len ) <= self .max_seq_len
109112
110113 k_out = self .k_cache
111114 v_out = self .v_cache
112115
113116 if self .transpose_cache :
114- k_out [:, :, self .cache_pos [:seq_len ]] = k_val
115- v_out [:, :, self .cache_pos [:seq_len ]] = v_val
117+ k_out [:, :, self .kv_cache_pos [:seq_len ]] = k_val
118+ v_out [:, :, self .kv_cache_pos [:seq_len ]] = v_val
116119 else :
117- k_out [:, self .cache_pos [:seq_len ]] = k_val
118- v_out [:, self .cache_pos [:seq_len ]] = v_val
120+ k_out [:, self .kv_cache_pos [:seq_len ]] = k_val
121+ v_out [:, self .kv_cache_pos [:seq_len ]] = v_val
119122
120123 # advance kv_cache_pos by seq_len positions
121124 # kv_cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
@@ -124,7 +127,7 @@ def update(
124127 # this allows us to track the current position in the cache
125128 # after the last update in a compile-friendly way without any dynamism
126129 # e.g. relying on an int size tracker, or re-creating kv_cache_pos every time
127- self .cache_pos .add_ (seq_len )
130+ self .kv_cache_pos .add_ (seq_len )
128131
129132 return k_out , v_out
130133
@@ -144,5 +147,5 @@ def clone(self) -> "KVCache":
144147 )
145148 clone .k_cache .copy_ (self .k_cache )
146149 clone .v_cache .copy_ (self .v_cache )
147- clone .cache_pos .copy_ (self .cache_pos )
150+ clone .kv_cache_pos .copy_ (self .kv_cache_pos )
148151 return clone
0 commit comments