
Commit ff9ddd6

[Executorch][llama] Update SDPA op to use quantized kv cache
With a quantized KV cache, we cannot rely on the SDPA op to update the original cache, so we insert a cache update op.

Differential Revision: [D62301841](https://our.internmc.facebook.com/intern/diff/D62301841/)
ghstack-source-id: 243859224
Pull Request resolved: #5528

5 files changed: +90, -43 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 3 deletions
@@ -474,9 +474,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (args.use_kv_cache is True) and (
-            args.use_sdpa_with_kv_cache is False
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache is True, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
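
Net effect of this hunk: quantize_kv_cache no longer excludes use_sdpa_with_kv_cache, so the quantized-cache transform and the custom-SDPA transform can now be applied together. A minimal sketch of the combined selection, assuming the transform helpers are importable from the source_transformation modules shown in the file tree (the SimpleNamespace here stands in for the parsed CLI args):

from types import SimpleNamespace

from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# Illustrative flags; real values come from the export script's argparse.
args = SimpleNamespace(use_kv_cache=True, use_sdpa_with_kv_cache=True, quantize_kv_cache=True)

transforms = []
if args.use_sdpa_with_kv_cache:
    transforms.append(replace_sdpa_with_custom_op)
if args.quantize_kv_cache:
    # The relaxed check: only use_kv_cache is required now.
    assert args.use_kv_cache is True, "quantize_kv_cache requires use_kv_cache=True"
    transforms.append(replace_kv_cache_with_quantized_kv_cache)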

examples/models/llama2/source_transformation/TARGETS

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ runtime.python_test(
     srcs = [
         "test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:update_quantized_cache_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
         "//caffe2:torch",

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 66 additions & 38 deletions
@@ -40,6 +40,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -97,51 +98,78 @@ def update(self, input_pos, k_val, v_val):
             torch.int8,
         )
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                if self.is_transposed:
+                    dim_to_slice = 2
+                else:
+                    dim_to_slice = 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
             else:
-                dim_to_slice = 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In the future we can update the custom op to handle a transposed
+            # cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache, with dynamic
+            # shape, instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
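
For intuition, on the non-transposed [batch, seq, ...] layout the custom op has the same effect as the narrow-and-copy pattern the deleted branch used. A hedged eager-mode sketch of that behavior (not the actual C++ kernel under extension/llm/custom_ops):

import torch

def update_quantized_cache_eager(value, cache, start_pos):
    # In-place write of `value` into `cache` at start_pos along the
    # sequence dimension (dim 1), mirroring narrowed_k.copy_(quantized_k_val).
    seq_length = value.size(1)
    cache.narrow(1, start_pos, seq_length).copy_(value)
    return cache

# Example: write a 2-token update into an int8 cache at position 4.
cache = torch.zeros(1, 16, 8, dtype=torch.int8)
update = torch.ones(1, 2, 8, dtype=torch.int8)
update_quantized_cache_eager(update, cache, start_pos=4)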

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 19 additions & 2 deletions
@@ -14,6 +14,9 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
@@ -36,12 +39,26 @@ def forward(
         seqlen,
         mask,
     ):
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Update the quantized cache, scales, and zero points;
+            # returns the dequantized kv cache.
+            # Not the most optimal; optimizations to follow.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache and v_cache.
+        # When we are not using the quantized kv cache, this just mutates
+        # the original kv cache.
+        # When we are using the quantized kv cache, this mutates the
+        # k_cache and v_cache returned from the cache update operation,
+        # which just dequantizes the cache and returns that.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
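
One consequence worth noting from the comments above: on the quantized path, sdpa_with_kv_cache's in-place writes land in the dequantized copies returned by update, not in the int8 cache that update already refreshed. A small runnable illustration of that aliasing point:

import torch

int8_cache = torch.zeros(1, 4, 2, dtype=torch.int8)
dequantized = int8_cache.float()      # a copy, like the tensors update() returns
dequantized[:, 0] = 1.0               # an SDPA-style in-place write
assert int8_cache.abs().sum() == 0    # the quantized cache is untouched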

extension/llm/custom_ops/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ def define_common_targets():
         "op_sdpa.h",
     ],
     exported_deps = [
+        ":update_quantized_cache",
         "//executorch/runtime/kernel:kernel_includes",
         "//executorch/kernels/portable/cpu:scalar_utils",
         "//executorch/kernels/optimized:libblas{}".format(mkl_dep),
