
Commit adfc140

Don't quantize the current token for attention
Pull Request resolved: #5715
ghstack-source-id: 245751541
@exported-using-ghexport
Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/)
1 parent ba3a9a3 commit adfc140
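
Note on the change: the KV cache itself still holds quantized values, but the dequantized `k_out`/`v_out` that attention consumes now have the current token's original full-precision `k_val`/`v_val` written into them (either via `narrow().copy_()` or via the `update_quantized_cache` op), so the newest token is not pushed through a quantize/dequantize round trip. A standalone sketch of that idea; the helper name, shapes, and the simplified per-token dequantization below are illustrative, not code from this commit:

```python
import torch

def dequantize_cache_keep_current_token(
    q_cache: torch.Tensor,       # int8 cache, shape (bsz, max_seq_len, dim)
    scales: torch.Tensor,        # per-token scales, shape (bsz, max_seq_len, 1)
    zero_points: torch.Tensor,   # per-token zero points, shape (bsz, max_seq_len, 1)
    current_val: torch.Tensor,   # full-precision k or v for this step, (bsz, seq_len, dim)
    start_pos: int,
) -> torch.Tensor:
    # Dequantize the whole cache so attention can run in floating point ...
    out = (q_cache.to(torch.float32) - zero_points.to(torch.float32)) * scales
    # ... then overwrite the slots for the current token(s) with the original
    # full-precision values, so the newest token sees no quantization error.
    seq_len = current_val.size(1)
    out.narrow(1, start_pos, seq_len).copy_(current_val)
    return out
```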

File tree

2 files changed: +21 −6 lines changed


examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -188,6 +188,27 @@ def update(self, input_pos, k_val, v_val):
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
+
+        if self.is_transposed:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k.copy_(k_val)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v.copy_(v_val)
+            else:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+        else:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+
         return k_out, v_out
 
     @classmethod
```
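
In the dynamic-shape branch above, the dequantized cache output is updated in place with `narrow()`/`copy_()`, and `torch._check_is_size` / `torch._check` give the exporter the bounds information it needs for the slice. A minimal sketch of that pattern in isolation; the shapes and the single-token decode step are illustrative, not taken from the model:

```python
import torch

# Illustrative "transposed" cache layout: (bsz, n_heads, max_seq_len, head_dim),
# so the sequence dimension to slice is dim 2.
bsz, n_heads, max_seq_len, head_dim = 1, 8, 128, 64
k_out = torch.zeros(bsz, n_heads, max_seq_len, head_dim)

# Full-precision keys for the current decode step (one new token).
k_val = torch.randn(bsz, n_heads, 1, head_dim)
input_pos = torch.tensor([5], dtype=torch.int64)

start_pos = input_pos[0].item()
torch._check_is_size(start_pos)                     # export hint: start_pos is a valid size
dim_to_slice = 2
torch._check(start_pos < k_out.size(dim_to_slice))  # export hint: slice stays in bounds
seq_length = k_val.size(dim_to_slice)

# narrow() returns a view; copy_() writes k_val into that slice in place,
# so the current token's values land in the attention input unquantized.
k_out.narrow(dim_to_slice, start_pos, seq_length).copy_(k_val)
```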

examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -66,12 +66,6 @@ def test_simple(self, is_dynamic_shape=False):
         torch.testing.assert_close(
             float_out,
             quantized_out,
-            # had to adjust rtol because switching to using custom_sdpa means we
-            # will use dequantized k and v instead of original k and v
-            # this leads to larger differences in the output.
-            # subsequent diff in the stack will address this issue.
-            rtol=1e-01,
-            atol=1e-03,
         )
 
         input_pos = torch.tensor([3], dtype=torch.int64)
```
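
Because the quantized path now feeds the current token's original k/v into attention, the test drops the loosened `rtol`/`atol` and relies on `torch.testing.assert_close`'s dtype-based defaults (`rtol=1.3e-6`, `atol=1e-5` for float32). An illustration of the comparison the test performs; the tensors here are placeholders, not the model outputs:

```python
import torch

# Stand-ins for the float-reference and quantized-cache SDPA outputs.
float_out = torch.randn(1, 1, 8, 64)
quantized_out = float_out.clone()

# With the current token kept in full precision, the two outputs are expected
# to match within assert_close's default float32 tolerances, so no custom
# rtol/atol is passed anymore.
torch.testing.assert_close(float_out, quantized_out)
```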
