
Commit 25784fa

[ExecuTorch] Add quantized kv cache to llama
This diff adds:
- a quantized kv cache implementation and the corresponding source transforms
- support for per-token quant/dequant in the quantized kernels

Differential Revision: [D62301844](https://our.internmc.facebook.com/intern/diff/D62301844/)
ghstack-source-id: 244449199
Pull Request resolved: pytorch/executorch#5598
1 parent 5028fc5 commit 25784fa
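The mechanism behind both bullet points is per-token int8 quantization: each token's k/v projection row gets its own scale and zero point, stored next to the int8 cache and used to dequantize on read. A minimal standalone sketch of that roundtrip (not part of the commit; the shapes and tolerance are illustrative) using the same torch.ops.quantized_decomposed ops the new cache calls:

import torch
# importing this module registers the quantized_decomposed ops used below
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

# (batch, seq, n_heads, head_dim): each head_dim-sized row is one "token"
x = torch.rand(1, 4, 8, 16)

# pick an asymmetric int8 scale/zero_point per token
scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
    x, torch.int8
)
x_q = torch.ops.quantized_decomposed.quantize_per_token(
    x, scales, zero_points, torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max, torch.int8
)
x_dq = torch.ops.quantized_decomposed.dequantize_per_token(
    x_q, scales, zero_points, torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max,
    torch.int8, torch.float32
)
# per-token int8 rounding error is at most ~scale/2, well under 1e-2 for values in [0, 1)
assert torch.allclose(x, x_dq.to(x.dtype), atol=1e-2)
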

File tree

15 files changed: +1144 −35 lines


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
+        "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 45 additions & 1 deletion
@@ -53,7 +53,11 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_quantized_kv_cache,
+)
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
+
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,
@@ -206,6 +210,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--quantize_kv_cache",
+        default=False,
+        action="store_true",
+        help="Whether or not to export a model using quantized kv cache",
+    )
     parser.add_argument(
         "--num_sharding",
         type=int,
@@ -428,7 +438,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:

     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
-
     # load model from checkpoint and params.json
     checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
     checkpoint_dir = (
@@ -446,6 +455,41 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None

+    # source transforms
+    transforms = []
+    if args.quantization_mode:
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
+    if args.embedding_quantize:
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
+    if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_sdpa_with_custom_op)
+
+    if args.quantize_kv_cache:
+        assert (
+            args.use_kv_cache and not args.use_sdpa_with_kv_cache
+        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+
+    if args.use_kv_cache:
+        if args.qnn:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+
+        elif args.coreml or args.mps:
+            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+            # to get free perf gain.
+            transforms.append(replace_sdpa_with_simple_sdpa)
+            transforms.append(replace_causal_mask)
     return (
         _load_llama_model(
             modelname=modelname,
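The hunk above only builds the transforms list; where it is consumed lies outside this diff. A hedged sketch of the intended pattern, with a hypothetical apply_source_transforms helper: each entry (including replace_kv_cache_with_quantized_kv_cache) takes an nn.Module and returns the transformed module, so the transforms compose by simple iteration:

import torch.nn as nn
from typing import Callable, List

def apply_source_transforms(
    model: nn.Module,
    transforms: List[Callable[[nn.Module], nn.Module]],
) -> nn.Module:
    # apply each source transform in the order it was appended
    for transform in transforms:
        model = transform(model)
    return model
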
examples/models/llama2/source_transformation/TARGETS

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "quantized_kv_cache",
+    srcs = [
+        "quantized_kv_cache.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2.source_transformation",
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_kv_cache_test",
+    srcs = [
+        "test_quantized_kv_cache.py",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
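The python_test target references test_quantized_kv_cache.py, which is not among the files shown here. A minimal, hedged sketch of the kind of roundtrip check such a test could contain, using only the QuantizedKVCache API defined in quantized_kv_cache.py (next file); the shapes, position, and tolerance are illustrative:

import torch
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

def test_update_roundtrip():
    # non-transposed cache layout: (batch, seq, n_heads, head_dim)
    cache = QuantizedKVCache(
        max_batch_size=1,
        max_seq_length=8,
        n_heads=4,
        head_dim=16,
        cache_type=QuantizedCacheType.AffineAsymmetric,
    )
    k = torch.rand(1, 1, 4, 16)
    v = torch.rand(1, 1, 4, 16)
    k_out, v_out = cache.update(torch.tensor([0]), k, v)
    # the written position should dequantize back to the input within int8 error
    assert torch.allclose(k_out[:, 0:1].to(k.dtype), k, atol=1e-2)
    assert torch.allclose(v_out[:, 0:1].to(v.dtype), v, atol=1e-2)
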
examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from enum import Enum
+
+import torch
+import torch.nn as nn
+from executorch.examples.models.llama2.llama_transformer import KVCache
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+
+
+"""
+Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
+"""
+
+
+# Doesn't have to abide by affine quantization laws
+# However, if we do implement quantized sdpa, then this might be handy
+class QuantizedCacheType(Enum):
+    AffineSymmetric = 0
+    AffineAsymmetric = 1
+    AffineSymmetricGroupWise = 1
+    AffineAsymmetricGroupWise = 2
+
+
+class QuantizedKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size,
+        max_seq_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        tranposed=False,
+        enable_dynamic_shape=False,
+    ):
+        super().__init__()
+        if cache_type not in (
+            QuantizedCacheType.AffineSymmetric,
+            QuantizedCacheType.AffineAsymmetric,
+        ):
+
+            raise ValueError(
+                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
+            )
+        # For now supporting int8 only
+        self.quantized_cache_dtype = torch.int8
+        self.cache_fp_type = torch.float32
+        self.is_transposed = tranposed
+        self.enable_dynamic_shape = enable_dynamic_shape
+        if self.is_transposed:
+            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
+        else:
+            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer(
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        self.register_buffer(
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        if cache_type == QuantizedCacheType.AffineAsymmetric:
+            self.register_buffer(
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+            )
+            self.register_buffer(
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+            )
+
+    def update(self, input_pos, k_val, v_val):
+        # quantize current k_val and store it in the cache
+        k_scales, k_zero_points = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                k_val, torch.int8  # no other value is supported by this op anyway
+            )
+        )
+        quantized_k_val = torch.ops.quantized_decomposed.quantize_per_token(
+            k_val,
+            k_scales,
+            k_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            torch.int8,
+        )
+
+        v_scales, v_zero_points = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
+                v_val, torch.int8
+            )
+        )
+        quantized_v_val = torch.ops.quantized_decomposed.quantize_per_token(
+            v_val,
+            v_scales,
+            v_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            torch.int8,
+        )
+
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[0].item()
+            torch._check_is_size(start_pos)
+            if self.is_transposed:
+                dim_to_slice = 2
+            else:
+                dim_to_slice = 1
+            torch._check(start_pos < self.k_cache.size(dim_to_slice))
+            seq_length = k_val.size(dim_to_slice)
+            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+            narrowed_k_scales = self.k_cache_scales.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_k_zp = self.k_cache_zero_points.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_k.copy_(quantized_k_val)
+            narrowed_k_scales.copy_(k_scales)
+            narrowed_k_zp.copy_(k_zero_points)
+            # pyre-ignore: Incompatible parameter type [6]
+            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+            narrowed_v_scales = self.v_cache_scales.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_v_zp = self.v_cache_zero_points.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_v.copy_(quantized_v_val)
+            narrowed_v_scales.copy_(v_scales)
+            narrowed_v_zp.copy_(v_zero_points)
+        else:
+            if self.is_transposed:
+                self.k_cache[:, :, input_pos] = quantized_k_val
+                self.k_cache_scales[:, :, input_pos] = k_scales
+                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
+                self.v_cache[:, :, input_pos] = quantized_v_val
+                self.v_cache_scales[:, :, input_pos] = v_scales
+                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+            else:
+                self.k_cache[:, input_pos] = quantized_k_val
+                self.k_cache_scales[:, input_pos] = k_scales
+                self.k_cache_zero_points[:, input_pos] = k_zero_points
+                self.v_cache[:, input_pos] = quantized_v_val
+                self.v_cache_scales[:, input_pos] = v_scales
+                self.v_cache_zero_points[:, input_pos] = v_zero_points
+
+        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
+            self.k_cache,
+            self.k_cache_scales,
+            self.k_cache_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            self.quantized_cache_dtype,
+            self.cache_fp_type,
+        )
+        v_out = torch.ops.quantized_decomposed.dequantize_per_token(
+            self.v_cache,
+            self.v_cache_scales,
+            self.v_cache_zero_points,
+            torch.iinfo(torch.int8).min,
+            torch.iinfo(torch.int8).max,
+            self.quantized_cache_dtype,
+            self.cache_fp_type,
+        )
+        return k_out, v_out
+
+    @classmethod
+    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
+        cache_shape = kv_cache.k_cache.shape
+        if kv_cache.is_tranposed:
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+        else:
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        return cls(
+            max_batch_size,
+            max_seq_length,
+            n_heads,
+            head_dim,
+            cache_type,
+            kv_cache.is_tranposed,
+            kv_cache.enable_dynamic_shape,
+        )
+
+
+def replace_kv_cache_with_quantized_kv_cache(module):
+    logging.warning(
+        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            setattr(
+                module,
+                name,
+                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
+            )
+        else:
+            replace_kv_cache_with_quantized_kv_cache(child)
+    return module
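export_llama_lib.py appends replace_kv_cache_with_quantized_kv_cache to the transform list when --quantize_kv_cache is set (together with --use_kv_cache and without --use_sdpa_with_kv_cache). Outside that export flow the transform can also be applied directly; a hedged sketch with a hypothetical quantize_model_kv_cache wrapper over an already-constructed llama Transformer:

import torch.nn as nn
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)

def quantize_model_kv_cache(model: nn.Module) -> nn.Module:
    # Recursively swaps every KVCache child for a QuantizedKVCache via from_float();
    # the model is modified in place and also returned for convenience.
    return replace_kv_cache_with_quantized_kv_cache(model)

After the swap, each former KVCache stores int8 k/v values with per-token scales and zero points, and update() returns dequantized float32 tensors, so downstream attention code keeps consuming floats.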
