Commit 055bed5

kimishpatel authored and facebook-github-bot committed
Add quantized kv cache to llama (#5664)

Summary:
Pull Request resolved: #5664

This diff:
- adds a quantized kv cache implementation and applies the corresponding source transforms
- adds support for per-token quant/dequant in the quantized kernels

ghstack-source-id: 245703480
exported-using-ghexport
bypass-github-export-checks (OSS CI complained about an internal-only lint and build failure)

Reviewed By: malfet, digantdesai

Differential Revision: D62301844

fbshipit-source-id: d89679b5793b73de4e85cb1111414e41ec2e7faf
1 parent 8ddb846 commit 055bed5
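To make the summary concrete, here is a minimal plain-PyTorch sketch (my own, not part of this commit) of what per-token asymmetric int8 quantization computes: one (scale, zero_point) pair per slice along the last dimension, which is roughly how the decomposed per-token ops treat the k/v tensors in the new cache below. The real cache uses the torch.ops.quantized_decomposed operators shown in the new file, not these helpers.

# Sketch (not from this diff): per-token asymmetric int8 quant/dequant in plain PyTorch.
import torch

def quantize_per_token_int8(x: torch.Tensor):
    # One (scale, zero_point) pair per slice of the last dimension.
    qmin, qmax = -128, 127
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    # Make sure the range covers zero so that zero is exactly representable.
    x_min = torch.minimum(x_min, torch.zeros_like(x_min))
    x_max = torch.maximum(x_max, torch.zeros_like(x_max))
    scale = torch.clamp((x_max - x_min) / (qmax - qmin), min=torch.finfo(x.dtype).eps)
    zero_point = (qmin - torch.round(x_min / scale)).clamp(qmin, qmax)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale, zero_point

def dequantize_per_token_int8(q, scale, zero_point, out_dtype=torch.float32):
    return ((q.to(torch.float64) - zero_point) * scale).to(out_dtype)

x = torch.randn(1, 8, 4, 64)          # (batch, seq, heads, head_dim)
q, s, zp = quantize_per_token_int8(x)
x_hat = dequantize_per_token_int8(q, s, zp)
print((x - x_hat).abs().max())        # small reconstruction error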

File tree

15 files changed: +1106, -35 lines


examples/models/llama2/TARGETS

Lines changed: 25 additions & 0 deletions
@@ -82,6 +82,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
+        "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
@@ -154,3 +155,27 @@ runtime.python_library(
         "//executorch/extension/pybindings:portable_lib",
     ],
 )
+
+runtime.python_library(
+    name = "quantized_kv_cache",
+    srcs = [
+        "source_transformation/quantized_kv_cache.py",
+    ],
+    _is_external_target = True,
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_kv_cache_test",
+    srcs = [
+        "source_transformation/test_quantized_kv_cache.py",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)

examples/models/llama2/export_llama_lib.py

Lines changed: 16 additions & 1 deletion
@@ -53,7 +53,11 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_quantized_kv_cache,
+)
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
+
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,
@@ -206,6 +210,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--quantize_kv_cache",
+        default=False,
+        action="store_true",
+        help="Whether or not to export a model using int8 per token quantized kv cache",
+    )
     parser.add_argument(
         "--num_sharding",
         type=int,
@@ -460,7 +470,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
-
     # load model from checkpoint and params.json
     checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
     checkpoint_dir = (
@@ -880,6 +889,12 @@ def _get_source_transforms( # noqa
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
 
+    if args.quantize_kv_cache:
+        assert (
+            args.use_kv_cache and not args.use_sdpa_with_kv_cache
+        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+
     if args.use_kv_cache:
         if args.qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
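For readers who want to sanity-check the flag wiring, the following is a small sketch of mine, not part of the commit. It assumes the executorch repo is importable and that build_args_parser has no required arguments at this revision; it only exercises argument parsing, not a full export.

# Sketch (assumptions above): --quantize_kv_cache is an ordinary store_true flag,
# so the gating logic can be inspected without running an export.
from executorch.examples.models.llama2.export_llama_lib import build_args_parser

args = build_args_parser().parse_args(["--use_kv_cache", "--quantize_kv_cache"])
assert args.use_kv_cache and args.quantize_kv_cache and not args.use_sdpa_with_kv_cache
# With this combination, _get_source_transforms appends
# replace_kv_cache_with_quantized_kv_cache; adding --use_sdpa_with_kv_cache as well
# would trip the assert introduced in this diff.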
examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from enum import Enum
+
+import torch
+import torch.nn as nn
+from executorch.examples.models.llama2.llama_transformer import KVCache
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+
+
+"""
+Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
+"""
+
+
+# Doesn't have to abide by affine quantization laws
+# However, if we do implement quantized sdpa, then this might be handy
+class QuantizedCacheType(Enum):
+    AffineSymmetric = 0
+    AffineAsymmetric = 1
+    AffineSymmetricGroupWise = 2
+    AffineAsymmetricGroupWise = 3
+
+
+class QuantizedKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size,
+        max_seq_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        tranposed=False,
+        enable_dynamic_shape=False,
+    ):
+        super().__init__()
+        if cache_type not in (
+            QuantizedCacheType.AffineSymmetric,
+            QuantizedCacheType.AffineAsymmetric,
+        ):
+
+            raise ValueError(
+                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
+            )
+        # For now supporting int8 only
+        self.quantized_cache_dtype = torch.int8
+        self.cache_fp_type = torch.float32
+        self.is_transposed = tranposed
+        self.enable_dynamic_shape = enable_dynamic_shape
+        if self.is_transposed:
+            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
+        else:
+            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
+        )
+        self.register_buffer(
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        self.register_buffer(
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        if cache_type == QuantizedCacheType.AffineAsymmetric:
+            self.register_buffer(
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+            )
+            self.register_buffer(
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+            )
+
+    def _quantize(self, value):
+        scales, zero_points = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                value, self.quantized_cache_dtype
+            )
+        )
+        quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
+            value,
+            scales,
+            zero_points,
+            torch.iinfo(self.quantized_cache_dtype).min,
+            torch.iinfo(self.quantized_cache_dtype).max,
+            self.quantized_cache_dtype,
+        )
+        return quantized_value, scales, zero_points
+
+    def update(self, input_pos, k_val, v_val):
+        # quantize current k_val and store it in the cache
+        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
+
+        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
+
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[0].item()
+            torch._check_is_size(start_pos)
+            dim_to_slice = 2 if self.is_transposed else 1
+            torch._check(start_pos < self.k_cache.size(dim_to_slice))
+            seq_length = k_val.size(dim_to_slice)
+            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+            narrowed_k_scales = self.k_cache_scales.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_k_zp = self.k_cache_zero_points.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_k.copy_(quantized_k_val)
+            narrowed_k_scales.copy_(k_scales)
+            narrowed_k_zp.copy_(k_zero_points)
+            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+            narrowed_v_scales = self.v_cache_scales.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_v_zp = self.v_cache_zero_points.narrow(
+                dim_to_slice, start_pos, seq_length
+            )
+            narrowed_v.copy_(quantized_v_val)
+            narrowed_v_scales.copy_(v_scales)
+            narrowed_v_zp.copy_(v_zero_points)
+        else:
+            if self.is_transposed:
+                self.k_cache[:, :, input_pos] = quantized_k_val
+                self.k_cache_scales[:, :, input_pos] = k_scales
+                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
+                self.v_cache[:, :, input_pos] = quantized_v_val
+                self.v_cache_scales[:, :, input_pos] = v_scales
+                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+            else:
+                self.k_cache[:, input_pos] = quantized_k_val
+                self.k_cache_scales[:, input_pos] = k_scales
+                self.k_cache_zero_points[:, input_pos] = k_zero_points
+                self.v_cache[:, input_pos] = quantized_v_val
+                self.v_cache_scales[:, input_pos] = v_scales
+                self.v_cache_zero_points[:, input_pos] = v_zero_points
+
+        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
+            self.k_cache,
+            self.k_cache_scales,
+            self.k_cache_zero_points,
+            torch.iinfo(self.quantized_cache_dtype).min,
+            torch.iinfo(self.quantized_cache_dtype).max,
+            self.quantized_cache_dtype,
+            self.cache_fp_type,
+        )
+        v_out = torch.ops.quantized_decomposed.dequantize_per_token(
+            self.v_cache,
+            self.v_cache_scales,
+            self.v_cache_zero_points,
+            torch.iinfo(self.quantized_cache_dtype).min,
+            torch.iinfo(self.quantized_cache_dtype).max,
+            self.quantized_cache_dtype,
+            self.cache_fp_type,
+        )
+        return k_out, v_out
+
+    @classmethod
+    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
+        cache_shape = kv_cache.k_cache.shape
+        if kv_cache.is_tranposed:
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+        else:
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        return cls(
+            max_batch_size,
+            max_seq_length,
+            n_heads,
+            head_dim,
+            cache_type,
+            kv_cache.is_tranposed,
+            kv_cache.enable_dynamic_shape,
+        )
+
+
+def replace_kv_cache_with_quantized_kv_cache(module):
+    logging.warning(
+        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            setattr(
+                module,
+                name,
+                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
+            )
+        else:
+            replace_kv_cache_with_quantized_kv_cache(child)
+    return module
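A quick usage sketch of mine, not part of the diff, exercising the new cache in isolation. It assumes the file lands at examples/models/llama2/source_transformation/quantized_kv_cache.py (as the TARGETS entry suggests) and uses the non-transposed layout (batch, seq, heads, head_dim).

# Sketch (assumed module path): one decode step against QuantizedKVCache.
import torch
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

cache = QuantizedKVCache(
    max_batch_size=1,
    max_seq_length=128,
    n_heads=8,
    head_dim=64,
    cache_type=QuantizedCacheType.AffineAsymmetric,
)

# One new token at position 5: k/v are quantized per token into the int8 buffers,
# and update() returns the whole cache dequantized back to fp32 for attention.
input_pos = torch.tensor([5])
k = torch.randn(1, 1, 8, 64)
v = torch.randn(1, 1, 8, 64)
k_out, v_out = cache.update(input_pos, k, v)
print(k_out.shape, k_out.dtype)  # expected: torch.Size([1, 128, 8, 64]) torch.float32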
