Commit f45aebb

[Executorch][llama] Update SDPA op to use quantized kv cache
Pull Request resolved: #5666

Using a quantized kv cache, we cannot rely on SDPA to update the original cache, so we insert a cache update op.

ghstack-source-id: 245718145
@exported-using-ghexport

Differential Revision: D62301841 (https://our.internmc.facebook.com/intern/diff/D62301841/)
1 parent d6d5b67 commit f45aebb
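
The core idea behind the change can be sketched in a few lines of plain PyTorch: with a quantized cache, SDPA can no longer mutate the original floating-point cache in place, so the cache update becomes its own step (quantize the new K/V slice, store int8 values plus scales and zero points, then dequantize before attention). The snippet below is an illustrative sketch only, not the ExecuTorch custom op; the per-token scheme, tensor names, and shapes are assumptions chosen for clarity.

import torch

# Illustrative sketch (not torch.ops.llama.update_quantized_cache): per-token
# asymmetric int8 quantization of a new K slice, an explicit cache-update
# step, and dequantization before attention.
def quantize_per_token(x: torch.Tensor):
    # x: [batch, seq, heads, head_dim]; quantize along the last dim.
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scales = (x_max - x_min).clamp(min=1e-6) / 255.0
    zero_points = (-128 - x_min / scales).round()
    q = (x / scales + zero_points).round().clamp(-128, 127).to(torch.int8)
    return q, scales, zero_points

def update_cache(new_slice, cache, start_pos):
    # Explicit cache update: copy the new slice into the preallocated cache.
    seq_len = new_slice.size(1)
    cache[:, start_pos : start_pos + seq_len] = new_slice

# Preallocated int8 cache plus float64 scales/zero points (as in the diff).
k_cache = torch.zeros(1, 8, 4, 16, dtype=torch.int8)
k_scales = torch.ones(1, 8, 4, 1, dtype=torch.float64)
k_zero_points = torch.zeros(1, 8, 4, 1, dtype=torch.float64)

k_new = torch.rand(1, 3, 4, 16)
q_k, s_k, zp_k = quantize_per_token(k_new)
update_cache(q_k, k_cache, 0)
update_cache(s_k, k_scales, 0)
update_cache(zp_k, k_zero_points, 0)

# Dequantize before handing the cache to SDPA.
k_dequant = (k_cache.to(torch.float64) - k_zero_points) * k_scales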

File tree

5 files changed: +205 −45 lines changed

  examples/models/llama2/TARGETS
  examples/models/llama2/export_llama_lib.py
  examples/models/llama2/source_transformation/quantized_kv_cache.py
  examples/models/llama2/source_transformation/sdpa.py
  examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py


examples/models/llama2/TARGETS

Lines changed: 31 additions & 0 deletions
@@ -160,13 +160,44 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "sdpa",
+    srcs = [
+        "source_transformation/sdpa.py",
+    ],
+    _is_external_target = True,
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
         "source_transformation/test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_sdpa_with_kv_cache_test",
+    srcs = [
+        "source_transformation/test_sdpa_with_quantized_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
+        ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama2:llama_transformer",
     ],
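
The preload_deps on custom_ops_aot_lib exist so that the torch.ops.llama.* kernels are registered before the tests run. Outside of Buck, a rough equivalent would be loading the AOT library by hand; the path below is hypothetical and depends on how //executorch/extension/llm/custom_ops:custom_ops_aot_lib was built.

import torch

# Hypothetical path; substitute the actual location of the built library.
torch.ops.load_library("path/to/libcustom_ops_aot_lib.so")

# After loading, the ops used in this PR resolve, e.g.:
#   torch.ops.llama.update_quantized_cache
#   torch.ops.llama.sdpa_with_kv_cache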

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 3 deletions
@@ -890,9 +890,7 @@ def _get_source_transforms(  # noqa
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
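
In effect, quantize_kv_cache now only requires use_kv_cache, so combining it with use_sdpa_with_kv_cache becomes legal. A minimal sketch of the resulting gating behavior is below; it is a simplified stand-in for _get_source_transforms (the condition guarding the SDPA transform is assumed from the surrounding context, and the real function appends callables rather than names).

from types import SimpleNamespace

def pick_transforms(args):
    # Simplified stand-in: return the names of the source transforms that
    # would be appended for these flags.
    transforms = []
    if args.use_sdpa_with_kv_cache:
        transforms.append("replace_sdpa_with_custom_op")
    if args.quantize_kv_cache:
        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append("replace_kv_cache_with_quantized_kv_cache")
    return transforms

# Previously this combination was rejected by the assert; it is now allowed.
args = SimpleNamespace(
    use_kv_cache=True, use_sdpa_with_kv_cache=True, quantize_kv_cache=True
)
print(pick_transforms(args))
# ['replace_sdpa_with_custom_op', 'replace_kv_cache_with_quantized_kv_cache']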

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 65 additions & 37 deletions
@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -65,10 +66,10 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
@@ -100,47 +101,74 @@ def update(self, input_pos, k_val, v_val):
 
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            dim_to_slice = 2 if self.is_transposed else 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs not.
+            # The only reason it is done this way is to accommodate
+            # lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
+            else:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now we are using custom ops on this path.
+            # In the future we can update the custom op to handle a transposed
+            # cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with dynamic
+            # shape, instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
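
Taken together with the sdpa.py change below, the intended calling pattern is: QuantizedKVCache.update quantizes and stores the new K/V slice (via torch.ops.llama.update_quantized_cache on the non-transposed path) and returns dequantized fp32 caches for SDPA to consume. A rough usage sketch follows; it assumes the ExecuTorch Python package is importable and the custom-ops library is already loaded, and the positional KVCache arguments mirror the new test added in this commit.

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

# Float cache converted to a quantized one (shapes and flags follow the test:
# batch=1, max_seq_len=5, n_kv_heads=4, head_dim=17, not transposed, static shape).
kv_cache = KVCache(1, 5, 4, 17, False, False, dtype=torch.float32)
quantized_kv_cache = QuantizedKVCache.from_float(
    kv_cache, QuantizedCacheType.AffineAsymmetric
)

input_pos = torch.tensor([0], dtype=torch.int64)
k = torch.rand(1, 3, 4, 17)
v = torch.rand(1, 3, 4, 17)

# update() writes int8 values plus scales/zero points into the registered
# buffers and returns dequantized fp32 caches.
k_out, v_out = quantized_kv_cache.update(input_pos, k, v)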

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 29 additions & 5 deletions
@@ -9,23 +9,32 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
+        kv_cache: Union[KVCache, QuantizedKVCache],
         dim: int,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
-        self.kv_cache = kv_cache.to(torch.float)
+        self.kv_cache = kv_cache
+        if not isinstance(kv_cache, QuantizedKVCache):
+            self.kv_cache = kv_cache.to(torch.float)
+        else:
+            assert (
+                kv_cache.cache_fp_type == torch.float32
+            ), "Only float32 is supported for custom SDPA"
         self.dim = dim
 
     def forward(
@@ -44,12 +53,27 @@ def forward(
         q = q.to(dtype=torch.float)
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
+
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Updates the quantized cache, scales and zero points,
+            # and returns the dequantized kv cache.
+            # Not the most optimal; optimizations to follow.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache and v_cache.
+        # When we are not using the quantized kv cache, this just mutates
+        # the original kv cache.
+        # When we are using the quantized kv cache, this mutates the
+        # k_cache and v_cache returned from the cache update operation,
+        # which just dequantizes the cache and returns it.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
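
The net effect in SDPACustom is that a single forward path serves both cache types: a plain KVCache is mutated in place by sdpa_with_kv_cache, while a QuantizedKVCache is updated (and dequantized) first and the op then reads the returned fp32 tensors. A rough construction sketch, mirroring the new test below; it assumes the ExecuTorch Python package and the custom-ops library are available.

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom

n_heads, n_kv_heads, head_dim, seq_len = 8, 4, 17, 3
# The two False flags mirror the positional arguments used in the new test.
kv_cache = KVCache(1, 5, n_kv_heads, head_dim, False, False, dtype=torch.float32)
quantized_kv_cache = QuantizedKVCache.from_float(
    kv_cache, QuantizedCacheType.AffineAsymmetric
)

# The same wrapper accepts either cache flavor.
float_sdpa = SDPACustom(kv_cache, n_heads * head_dim)
quantized_sdpa = SDPACustom(quantized_kv_cache, n_heads * head_dim)

input_pos = torch.tensor([0], dtype=torch.int64)
q = torch.rand(1, seq_len, n_heads, head_dim)
k = torch.rand(1, seq_len, n_kv_heads, head_dim)
v = torch.rand(1, seq_len, n_kv_heads, head_dim)

out = quantized_sdpa(input_pos, q, k, v, 1, seq_len, None)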
examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import unittest
+
+import torch
+
+from executorch.examples.models.llama2.llama_transformer import KVCache
+
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedCacheType,
+    QuantizedKVCache,
+)
+
+from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom
+
+
+class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
+
+    def _init_cache(self):
+        self.kv_cache = KVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            False,
+            self.enable_dynamic_shape,
+            dtype=self.dtype,
+        )
+        self.quantized_kv_cache = QuantizedKVCache.from_float(
+            self.kv_cache, QuantizedCacheType.AffineAsymmetric
+        )
+
+    def _init_kv(self):
+        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        q = torch.rand(q_shape, dtype=self.dtype)
+        k = torch.rand(kv_shape, dtype=self.dtype)
+        v = torch.rand(kv_shape, dtype=self.dtype)
+        return q, k, v
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.max_seq_len = 5
+        self.n_kv_heads = 4
+        self.n_heads = 8
+        self.head_dim = 17
+        self.dim = self.n_heads * self.head_dim
+        self.enable_dynamic_shape = False
+        self.dtype = torch.float32
+
+    def test_simple(self, is_dynamic_shape=False):
+        self.enable_dynamic_shape = is_dynamic_shape
+        input_pos = torch.tensor([0], dtype=torch.int64)
+        self.seq_len = 3
+        self._init_cache()
+        q, k, v = self._init_kv()
+        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+            )
+        )
+
+        input_pos = torch.tensor([3], dtype=torch.int64)
+        self.seq_len = 1
+        q, k, v = self._init_kv()
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+                rtol=1e-03,
+                atol=1e-03,
+            )
+        )
