Commit 9f43ab7

Don't quantize the current token for attention
Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/)
ghstack-source-id: 244978598
Pull Request resolved: #5715
1 parent 4eeee07 commit 9f43ab7

File tree

1 file changed: 21 additions, 0 deletions

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 21 additions & 0 deletions
@@ -189,6 +189,27 @@ def update(self, input_pos, k_val, v_val):
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
+
+        if self.is_transposed:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k.copy_(k_val)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v.copy_(v_val)
+            else:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+        else:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+
         return k_out, v_out

     @classmethod
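
For reference, the effect of this change is that attention sees the current token's K/V at full precision: the cache still stores quantized values, but the freshly computed k_val and v_val are copied straight into the dequantized k_out and v_out at the current position, instead of a quantize-then-dequantize round trip. Below is a minimal standalone sketch of that idea, not the ExecuTorch implementation; the shapes, the (B, H, S, D) transposed layout, the per-tensor scale, and the fake_dequantize helper are illustrative assumptions.

import torch

B, H, S, D = 1, 4, 16, 8                # assumed batch, heads, max seq len, head dim

def fake_dequantize(cache_int8, scale):
    # Stand-in for the per-token dequantize of the cache: a float view of the
    # quantized storage, which carries quantization error for past tokens.
    return cache_int8.to(torch.float32) * scale

scale = 0.05
k_cache_int8 = torch.randint(-128, 127, (B, H, S, D), dtype=torch.int8)
v_cache_int8 = torch.randint(-128, 127, (B, H, S, D), dtype=torch.int8)

k_out = fake_dequantize(k_cache_int8, scale)
v_out = fake_dequantize(v_cache_int8, scale)

# Current token's K/V as produced by the attention projections (unquantized).
input_pos = torch.tensor([5])           # position of the token being decoded
k_val = torch.randn(B, H, 1, D)
v_val = torch.randn(B, H, 1, D)

# Mirror the narrow().copy_() path in the diff: overwrite the current position
# in the dequantized output with the exact values, so only past tokens carry
# quantization error.
start_pos = input_pos[0].item()
seq_length = k_val.size(2)              # dim 2 is the sequence dim in this layout
k_out.narrow(2, start_pos, seq_length).copy_(k_val)
v_out.narrow(2, start_pos, seq_length).copy_(v_val)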
