Skip to content

Commit 3276197

Browse files
committed
[ExecuTorch] Some updates to kv cache
Update kv cache impl to consider untransposed cache Differential Revision: [D62301843](https://our.internmc.facebook.com/intern/diff/D62301843/) [ghstack-poisoned]
1 parent 3512148 commit 3276197

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(
151151
):
152152
super().__init__()
153153
self.max_seq_length = max_seq_length
154+
self.is_tranposed = transpose_cache
154155
if transpose_cache:
155156
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
156157
else:
@@ -173,28 +174,37 @@ def update(
173174
) -> Tuple[torch.Tensor, torch.Tensor]:
174175
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
175176
if self.enable_dynamic_shape:
176-
start_pos = input_pos[-1].item()
177+
start_pos = input_pos[0].item()
177178
torch._check_is_size(start_pos)
178179
torch._check(start_pos < self.max_seq_length)
179-
seq_length = k_val.size(2)
180+
if self.transpose_cache:
181+
dim_to_slice = 2
182+
else:
183+
dim_to_slice = 1
184+
seq_length = k_val.size(dim_to_slice)
180185
# Replace the entry in the cache for this token
181186
# The following lines are equivalent to:
182187
# cache_k[:bsz, start_pos : start_pos + seqlen] = xk
183188
# cache_v[:bsz, start_pos : start_pos + seqlen] = xv
189+
# when dim_to_slice is 1
184190
# We use .narrow() here to make the compiler happy
185191
# pyre-ignore: Incompatible parameter type [6]
186-
narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
192+
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
187193
# pyre-ignore: Incompatible parameter type [6]
188-
narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
194+
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
189195

190196
narrowed_k.copy_(k_val)
191197
narrowed_v.copy_(v_val)
192198
return self.k_cache, self.v_cache
193199
else:
194200
k_out = self.k_cache
195201
v_out = self.v_cache
196-
k_out[:, :, input_pos] = k_val
197-
v_out[:, :, input_pos] = v_val
202+
if self.transpose_cache:
203+
k_out[:, :, input_pos] = k_val
204+
v_out[:, :, input_pos] = v_val
205+
else:
206+
k_out[:, input_pos] = k_val
207+
v_out[:, input_pos] = v_val
198208

199209
return k_out, v_out
200210

0 commit comments

Comments
 (0)