@@ -151,6 +151,7 @@ def __init__(
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
+        self.transpose_cache = transpose_cache
         if transpose_cache:
             cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         else:
@@ -173,28 +174,34 @@ def update(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
+            start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            seq_length = k_val.size(2)
+            dim_to_slice = 2 if self.transpose_cache else 1
+            seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
             # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
             # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            # when dim_to_slice is 1
             # We use .narrow() here to make the compiler happy
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

             narrowed_k.copy_(k_val)
             narrowed_v.copy_(v_val)
             return self.k_cache, self.v_cache
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            k_out[:, :, input_pos] = k_val
-            v_out[:, :, input_pos] = v_val
+            if self.transpose_cache:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+            else:
+                k_out[:, input_pos] = k_val
+                v_out[:, input_pos] = v_val

             return k_out, v_out

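For illustration, here is a minimal standalone sketch, not part of the patch, showing why the write slices along `dim_to_slice` instead of a hard-coded dim 2: the sequence dimension is 2 for the transposed [B, H, S, D] layout and 1 for the non-transposed [B, S, H, D] layout, and `.narrow(dim, start, length).copy_()` writes the same entries as direct slice assignment in both cases. The `write_to_cache` helper and the tensor sizes are assumptions made for this example.

import torch

# Example sizes (assumed for illustration only).
B, H, S_MAX, D = 1, 4, 16, 8

def write_to_cache(cache, val, start_pos, transpose_cache):
    # Sequence dim is 2 for [B, H, S, D], 1 for [B, S, H, D], mirroring the patch.
    dim_to_slice = 2 if transpose_cache else 1
    seq_length = val.size(dim_to_slice)
    # narrow() returns a view, so copy_() writes into the underlying cache tensor.
    cache.narrow(dim_to_slice, start_pos, seq_length).copy_(val)
    return cache

# Transposed layout: cache is [B, H, S, D].
cache_t = torch.zeros(B, H, S_MAX, D)
k_t = torch.randn(B, H, 3, D)
ref_t = cache_t.clone()
ref_t[:, :, 5:8] = k_t
assert torch.equal(write_to_cache(cache_t, k_t, 5, True), ref_t)

# Non-transposed layout: cache is [B, S, H, D].
cache_n = torch.zeros(B, S_MAX, H, D)
k_n = torch.randn(B, 3, H, D)
ref_n = cache_n.clone()
ref_n[:, 5:8] = k_n
assert torch.equal(write_to_cache(cache_n, k_n, 5, False), ref_n)

The narrow-based form keeps the property the in-code comment points at ("make the compiler happy"): the update is expressed as a fixed-rank view plus an in-place copy rather than advanced indexing with a data-dependent position.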