Skip to content

Commit d51ef81

Browse files
committed
[Executorch][kv cache] Make quantized cache return only the updated cache
portion Differential Revision: [D69856554](https://our.internmc.facebook.com/intern/diff/D69856554/) ghstack-source-id: 272375854 Pull Request resolved: #9351
1 parent 80ec397 commit d51ef81

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def __init__(
4848
f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
4949
)
5050

51+
# This reduces dequantize overhead by dequantizing only up to the
52+
# current input position.
53+
self.return_sliced_cache = True
5154
# For now supporting int8 only
5255
self.use_custom_update_cache_op = use_custom_update_cache_op
5356
self.quantized_cache_dtype = torch.int8
@@ -125,19 +128,34 @@ def update(self, input_pos, k_val, v_val):
125128
self.v_cache_scales[:, input_pos] = v_scales
126129
self.v_cache_zero_points[:, input_pos] = v_zero_points
127130

131+
if self.return_sliced_cache:
132+
k_cache_out = self.k_cache[:, :input_pos+1]
133+
k_cache_scales = self.k_cache_scales[:, :input_pos+1]
134+
k_cache_zero_points = self.k_cache_zero_points[:, :input_pos+1]
135+
v_cache_out = self.v_cache[:, :input_pos+1]
136+
v_cache_scales = self.v_cache_scales[:, :input_pos+1]
137+
v_cache_zero_points = self.v_cache_zero_points[:, :input_pos+1]
138+
else:
139+
k_cache_out = self.k_cache
140+
k_cache_scales = self.k_cache_scales
141+
k_cache_zero_points = self.k_cache_zero_points
142+
v_cache_out = self.v_cache
143+
v_cache_scales = self.v_cache_scales
144+
v_cache_zero_points = self.v_cache_zero_points
145+
128146
k_out = torch.ops.quantized_decomposed.dequantize_per_token(
129-
self.k_cache,
130-
self.k_cache_scales,
131-
self.k_cache_zero_points,
147+
k_cache_out,
148+
k_cache_scales,
149+
k_cache_zero_points,
132150
torch.iinfo(self.quantized_cache_dtype).min,
133151
torch.iinfo(self.quantized_cache_dtype).max,
134152
self.quantized_cache_dtype,
135153
self.cache_fp_type,
136154
)
137155
v_out = torch.ops.quantized_decomposed.dequantize_per_token(
138-
self.v_cache,
139-
self.v_cache_scales,
140-
self.v_cache_zero_points,
156+
v_cache_out,
157+
v_cache_scales,
158+
v_cache_zero_points,
141159
torch.iinfo(self.quantized_cache_dtype).min,
142160
torch.iinfo(self.quantized_cache_dtype).max,
143161
self.quantized_cache_dtype,

0 commit comments

Comments
 (0)