Commit 06a1f07

[ExecuTorch] Add quantized kv cache to llama
This diff adds:
- a quantized KV cache implementation and the corresponding source transforms
- support for per-token quant/dequant in the quantized kernels

Differential Revision: [D62301844](https://our.internmc.facebook.com/intern/diff/D62301844/)
ghstack-source-id: 243859219
Pull Request resolved: #5525
1 parent 3fbc178 commit 06a1f07
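
For reference, a minimal sketch of the per-token quant/dequant flow this commit relies on, using the same quantized_decomposed ops that the new cache calls. The tensor shape and tolerance below are illustrative assumptions, not values taken from this diff. "Per token" means every slice along the last dimension gets its own int8 scale and zero point:

import torch

# Registers the torch.ops.quantized_decomposed.* ops used below.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

k_val = torch.randn(1, 4, 8, 64)  # e.g. (batch, seq_len, n_heads, head_dim)
scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
    k_val, torch.int8
)
q = torch.ops.quantized_decomposed.quantize_per_token(
    k_val, scales, zero_points,
    torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max, torch.int8,
)
dq = torch.ops.quantized_decomposed.dequantize_per_token(
    q, scales, zero_points,
    torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max,
    torch.int8, torch.float32,
)
# scales/zero_points carry one (scale, zero_point) pair per token
# (one per slice over the last dimension), so the round trip stays close.
assert torch.allclose(k_val, dq, atol=0.05)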

15 files changed: +1157, -35 lines changed

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
+        "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 45 additions & 1 deletion
@@ -53,7 +53,11 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_quantized_kv_cache,
+)
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
+
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,

@@ -206,6 +210,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--quantize_kv_cache",
+        default=False,
+        action="store_true",
+        help="Whether or not to export a model using quantized kv cache",
+    )
     parser.add_argument(
         "--num_sharding",
         type=int,

@@ -428,7 +438,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
-
     # load model from checkpoint and params.json
     checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
     checkpoint_dir = (

@@ -446,6 +455,41 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None
 
+    # source transforms
+    transforms = []
+    if args.quantization_mode:
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
+    if args.embedding_quantize:
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
+    if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_sdpa_with_custom_op)
+
+    if args.quantize_kv_cache:
+        assert (args.use_kv_cache is True) and (
+            args.use_sdpa_with_kv_cache is False
+        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+
+    if args.use_kv_cache:
+        if args.qnn:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+
+        elif args.coreml or args.mps:
+            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+            # to get free perf gain.
+            transforms.append(replace_sdpa_with_simple_sdpa)
+            transforms.append(replace_causal_mask)
     return (
         _load_llama_model(
             modelname=modelname,
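
The new --quantize_kv_cache flag hooks into the source-transform list built above: replace_kv_cache_with_quantized_kv_cache is just one more Module-to-Module callable appended to transforms. A rough sketch of how such a list is applied, under the assumption that the LLMEdgeManager returned below runs the transforms in order (this helper is illustrative, not part of the diff):

from typing import Callable, List

import torch.nn as nn


def apply_source_transforms(
    model: nn.Module, transforms: List[Callable[[nn.Module], nn.Module]]
) -> nn.Module:
    # Apply each transform in the order it was appended; ordering matters,
    # e.g. KV-cache replacement must see the original KVCache modules.
    for transform in transforms:
        model = transform(model)
    return model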
examples/models/llama2/source_transformation/TARGETS

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+runtime.python_library(
+    name = "quantized_kv_cache",
+    srcs = [
+        "quantized_kv_cache.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2.source_transformation",
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_kv_cache_test",
+    srcs = [
+        "test_quantized_kv_cache.py",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
+import logging
+from enum import Enum
+
+import torch
+import torch.nn as nn
+from executorch.examples.models.llama2.llama_transformer import KVCache
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+
+
+"""
+Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
+"""
+
+
+# Doesn't have to abide by affine quantization laws
+# However, if we do implement quantized sdpa, then this might be handy
+class QuantizedCacheType(Enum):
+    AffineSymmetric = 0
+    AffineAsymmetric = 1
+    AffineSymmetricGroupWise = 1
+    AffineAsymmetricGroupWise = 2
+
+
+class QuantizedKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size,
+        max_seq_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        tranposed=False,
+        enable_dynamic_shape=False,
+    ):
+        super().__init__()
+        if not (
+            cache_type == QuantizedCacheType.AffineSymmetric
+            or cache_type == QuantizedCacheType.AffineAsymmetric
+        ):
+            raise ValueError(
+                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
+            )
+        # For now supporting int8 only
+        self.quantized_cache_dtype = torch.int8
+        self.cache_fp_type = torch.float32
+        self.is_transposed = tranposed
+        self.enable_dynamic_shape = enable_dynamic_shape
+        if self.is_transposed:
+            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
+        else:
+            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer(
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        self.register_buffer(
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        if cache_type == QuantizedCacheType.AffineAsymmetric:
+            self.register_buffer(
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+            )
+            self.register_buffer(
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+            )
+
+    def update(self, input_pos, k_val, v_val):
+        # quantize current k_val and store it in the cache
+        k_scales, k_zero_points = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                k_val, torch.int8  # no other value is supported by this op anyway
+            )
+        )
+        quantized_k_val = torch.ops.quantized_decomposed.quantize_per_token(
+            k_val,
+            k_scales,
+            k_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            torch.int8,
+        )
+
+        v_scales, v_zero_points = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
+                v_val, torch.int8
+            )
+        )
+        quantized_v_val = torch.ops.quantized_decomposed.quantize_per_token(
+            v_val,
+            v_scales,
+            v_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            torch.int8,
+        )
+
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[0].item()
+            torch._check_is_size(start_pos)
+            if self.is_transposed:
+                dim_to_slice = 2
+            else:
+                dim_to_slice = 1
+            torch._check(start_pos < self.k_cache.size(dim_to_slice))
+            seq_length = k_val.size(dim_to_slice)
+            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+            narrowed_k_scales = self.k_cache_scales.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_k_zp = self.k_cache_zero_points.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_k.copy_(quantized_k_val)
+            narrowed_k_scales.copy_(k_scales)
+            narrowed_k_zp.copy_(k_zero_points)
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+            narrowed_v_scales = self.v_cache_scales.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_v_zp = self.v_cache_zero_points.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_v.copy_(quantized_v_val)
+            narrowed_v_scales.copy_(v_scales)
+            narrowed_v_zp.copy_(v_zero_points)
+        else:
+            if self.is_transposed:
+                self.k_cache[:, :, input_pos] = quantized_k_val
+                self.k_cache_scales[:, :, input_pos] = k_scales
+                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
+                self.v_cache[:, :, input_pos] = quantized_v_val
+                self.v_cache_scales[:, :, input_pos] = v_scales
+                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+            else:
+                self.k_cache[:, input_pos] = quantized_k_val
+                self.k_cache_scales[:, input_pos] = k_scales
+                self.k_cache_zero_points[:, input_pos] = k_zero_points
+                self.v_cache[:, input_pos] = quantized_v_val
+                self.v_cache_scales[:, input_pos] = v_scales
+                self.v_cache_zero_points[:, input_pos] = v_zero_points
+
+        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
+            self.k_cache,
+            self.k_cache_scales,
+            self.k_cache_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            self.quantized_cache_dtype,
+            self.cache_fp_type,
+        )
+        v_out = torch.ops.quantized_decomposed.dequantize_per_token(
+            self.v_cache,
+            self.v_cache_scales,
+            self.v_cache_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            self.quantized_cache_dtype,
+            self.cache_fp_type,
+        )
+        return k_out, v_out
+
+    @classmethod
+    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
+        cache_shape = kv_cache.k_cache.shape
+        if kv_cache.is_tranposed:
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+        else:
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        return cls(
+            max_batch_size,
+            max_seq_length,
+            n_heads,
+            head_dim,
+            cache_type,
+            kv_cache.is_tranposed,
+            kv_cache.enable_dynamic_shape,
+        )
+
+
+def replace_kv_cache_with_quantized_kv_cache(module):
+    logging.warning(
+        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            setattr(
+                module,
+                name,
+                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
+            )
+        else:
+            replace_kv_cache_with_quantized_kv_cache(child)
+    return module
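
A minimal sanity-check sketch of how the new cache is exercised (not part of this diff; the sizes and tolerance are illustrative assumptions, and the real coverage lives in test_quantized_kv_cache.py referenced by the TARGETS above). update() quantizes the incoming k/v per token into the int8 buffers at input_pos and returns the dequantized caches:

import torch

# QuantizedKVCache / QuantizedCacheType as defined in quantized_kv_cache.py above.
# Illustrative sizes: batch=1, max_seq_length=8, n_heads=4, head_dim=16, non-transposed layout.
cache = QuantizedKVCache(1, 8, 4, 16, cache_type=QuantizedCacheType.AffineAsymmetric)
k = torch.randn(1, 1, 4, 16)
v = torch.randn(1, 1, 4, 16)
input_pos = torch.tensor([0], dtype=torch.int64)

k_out, v_out = cache.update(input_pos, k, v)
# The dequantized cache entries at position 0 should closely track the float inputs.
assert torch.allclose(k_out[:, 0:1], k, atol=0.05)
assert torch.allclose(v_out[:, 0:1], v, atol=0.05)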
