Commit 6fec56f

Update on "Add update_quantized_cache op"
Why?
- Functionalization produces a ton of copies.
- Without such custom in-place ops, mutable buffer support results in giant copies at the end.
- Making in-place ops work properly will likely take longer, and there is no clear safe path.

Differential Revision: [D62301838](https://our.internmc.facebook.com/intern/diff/D62301838/)

[ghstack-poisoned]
2 parents ede4406 + d1dcdc6
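To make the motivation concrete, here is a minimal sketch of the difference between a functional cache update, which materializes a copy of the entire cache on every decode step, and an in-place update that only writes the new tokens. The helper names are illustrative; this is not the actual `update_quantized_cache` op or its signature.

    import torch

    def update_cache_functional(k_cache: torch.Tensor, k_val: torch.Tensor, start_pos: int) -> torch.Tensor:
        # Functional form: allocates and returns a tensor the size of the whole
        # cache; after functionalization this becomes a giant copy back into the
        # mutable buffer on every step.
        new_cache = k_cache.clone()
        new_cache[:, start_pos : start_pos + k_val.size(1)] = k_val
        return new_cache

    def update_cache_inplace(k_cache: torch.Tensor, k_val: torch.Tensor, start_pos: int) -> None:
        # In-place form: only the slice holding the new tokens is written,
        # which is what a custom in-place op preserves through export.
        k_cache.narrow(1, start_pos, k_val.size(1)).copy_(k_val)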

4 files changed: +46 -84 lines


examples/models/llama2/export_llama_lib.py

Lines changed: 7 additions & 36 deletions
@@ -214,7 +214,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantize_kv_cache",
         default=False,
         action="store_true",
-        help="Whether or not to export a model using quantized kv cache",
+        help="Whether or not to export a model using int8 per token quantized kv cache",
     )
     parser.add_argument(
         "--num_sharding",
@@ -455,41 +455,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None
 
-    # source transforms
-    transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
-
-    if args.embedding_quantize:
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
-
-    if args.expand_rope_table:
-        transforms.append(materialze_broadcast_of_rope_freq_cis)
-
-    if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
-
-    if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
-        transforms.append(replace_kv_cache_with_quantized_kv_cache)
-
-    if args.use_kv_cache:
-        if args.qnn:
-            transforms.append(replace_kv_cache_with_simple_kv_cache)
-            transforms.append(replace_sdpa_with_flex_sdpa)
-            transforms.append(replace_causal_mask)
-
-        elif args.coreml or args.mps:
-            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
-            # to get free perf gain.
-            transforms.append(replace_sdpa_with_simple_sdpa)
-            transforms.append(replace_causal_mask)
     return (
         _load_llama_model(
             modelname=modelname,
@@ -850,6 +815,12 @@ def _get_source_transforms(  # noqa
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
 
+    if args.quantize_kv_cache:
+        assert (
+            args.use_kv_cache and not args.use_sdpa_with_kv_cache
+        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+
     if args.use_kv_cache:
         if args.qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
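The net effect of this change is that the quantized-KV-cache rewrite now lives alongside the other source transforms in `_get_source_transforms` rather than being applied ad hoc in `_prepare_for_llama_export`. Each entry in `transforms` is a module-to-module callable applied to the eager model before export. As a rough, hypothetical sketch of what such a "replace X with Y" transform tends to look like (the real `replace_kv_cache_with_quantized_kv_cache` lives in source_transformation/quantized_kv_cache.py and differs in detail):

    import torch.nn as nn

    def replace_submodules(model: nn.Module, old_cls: type, make_new) -> nn.Module:
        # Walk the module tree and swap every instance of old_cls with the module
        # returned by make_new(old_module); transforms compose because each one
        # returns the (mutated) model.
        for name, child in model.named_children():
            if isinstance(child, old_cls):
                setattr(model, name, make_new(child))
            else:
                replace_submodules(child, old_cls, make_new)
        return model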

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 29 additions & 36 deletions
@@ -23,8 +23,8 @@
 class QuantizedCacheType(Enum):
     AffineSymmetric = 0
     AffineAsymmetric = 1
-    AffineSymmetricGroupWise = 1
-    AffineAsymmetricGroupWise = 2
+    AffineSymmetricGroupWise = 2
+    AffineAsymmetricGroupWise = 3
 
 
 class QuantizedKVCache(nn.Module):
@@ -58,8 +58,12 @@ def __init__(
         else:
             cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
             scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
-        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.int8))
-        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
+        )
         self.register_buffer(
             "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
         )
@@ -74,43 +78,32 @@ def __init__(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
         )
 
-    def update(self, input_pos, k_val, v_val):
-        # quantize current k_val and store it in the cache
-        k_scales, k_zero_points = (
+    def _quantize(self, value):
+        scales, zero_points = (
             torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
-                k_val, torch.int8  # no other value is supported by this op anyway
+                value, self.quantized_cache_dtype
             )
         )
-        quantized_k_val = torch.ops.quantized_decomposed.quantize_per_token(
-            k_val,
-            k_scales,
-            k_zero_points,
-            torch.iinfo(torch.int8).min,
-            torch.iinfo(torch.int8).max,
-            torch.int8,
+        quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
+            value,
+            scales,
+            zero_points,
+            torch.iinfo(self.quantized_cache_dtype).min,
+            torch.iinfo(self.quantized_cache_dtype).max,
+            self.quantized_cache_dtype,
         )
+        return quantized_value, scales, zero_points
 
-        v_scales, v_zero_points = (
-            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
-                v_val, torch.int8
-            )
-        )
-        quantized_v_val = torch.ops.quantized_decomposed.quantize_per_token(
-            v_val,
-            v_scales,
-            v_zero_points,
-            torch.iinfo(torch.int8).min,
-            torch.iinfo(torch.int8).max,
-            torch.int8,
-        )
+    def update(self, input_pos, k_val, v_val):
+        # quantize current k_val and store it in the cache
+        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
+
+        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
-            else:
-                dim_to_slice = 1
+            dim_to_slice = 2 if self.is_transposed else 1
             torch._check(start_pos < self.k_cache.size(dim_to_slice))
             seq_length = k_val.size(dim_to_slice)
             narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
@@ -154,17 +147,17 @@ def update(self, input_pos, k_val, v_val):
             self.k_cache,
             self.k_cache_scales,
             self.k_cache_zero_points,
-            torch.iinfo(torch.int8).min,
-            torch.iinfo(torch.int8).max,
+            torch.iinfo(self.quantized_cache_dtype).min,
+            torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.v_cache,
             self.v_cache_scales,
             self.v_cache_zero_points,
-            torch.iinfo(torch.int8).min,
-            torch.iinfo(torch.int8).max,
+            torch.iinfo(self.quantized_cache_dtype).min,
+            torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
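The refactor folds the duplicated K and V quantization into a single `_quantize` helper and threads `self.quantized_cache_dtype` through instead of hard-coding `torch.int8`. A standalone sketch of the same per-token asymmetric round trip, assuming the `quantized_decomposed` ops are registered (e.g. by importing `torch.ao.quantization.fx._decomposed`):

    import torch
    from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401  (registers the ops)

    def per_token_roundtrip(value: torch.Tensor, dtype: torch.dtype = torch.int8):
        # One (scale, zero_point) pair per token; scales come back as float64
        # and zero_points as int64, matching the buffer dtypes registered above.
        scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
            value, dtype
        )
        quantized = torch.ops.quantized_decomposed.quantize_per_token(
            value, scales, zero_points, torch.iinfo(dtype).min, torch.iinfo(dtype).max, dtype
        )
        dequantized = torch.ops.quantized_decomposed.dequantize_per_token(
            quantized, scales, zero_points, torch.iinfo(dtype).min, torch.iinfo(dtype).max, dtype, value.dtype
        )
        return quantized, scales, zero_points, dequantized

    # e.g. q, s, zp, dq = per_token_roundtrip(torch.randn(1, 4, 8, 16))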

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 10 additions & 8 deletions
@@ -394,17 +394,19 @@ Tensor& dequantize_per_token_out(
   for (size_t i = 0; i < input.dim() - 1; i++) {
     num_channels *= input.size(i);
   }
-  // This unfortunate change is needed because we compile op_quantize for aten
-  // mode as well
+  // This unfortunate change is needed because we compile op_quantize for aten
+  // mode as well
+  std::array<exec_aten::SizesType, 2> input_sizes;
+  input_sizes[0] = static_cast<exec_aten::SizesType>(num_channels);
+  input_sizes[1] =
+      static_cast<exec_aten::SizesType>(input.size(input.dim() - 1));
 #ifdef USE_ATEN_LIB
-  const std::array<int64_t, 2> sizes = {{num_channels, input.dim() - 1}};
   Tensor reshaped_input = at::from_blob(
-      input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type()));
+      input.mutable_data_ptr(),
+      input_sizes,
+      at::TensorOptions(input.scalar_type()));
 #else
   std::array<exec_aten::DimOrderType, 2> input_dim_order{0, 1};
-  std::array<exec_aten::SizesType, 2> input_sizes;
-  input_sizes[0] = num_channels;
-  input_sizes[1] = input.size(input.dim() - 1);
   std::array<exec_aten::StridesType, 2> input_strides;
   dim_order_to_stride_nocheck(
       input_sizes.data(), input_dim_order.data(), 2, input_strides.data());
@@ -428,7 +430,7 @@ Tensor& dequantize_per_token_out(
       reshaped_input,
       scale,
       zero_points,
-      0,
+      0, /* axis */
       quant_min,
       quant_max,
       dtype,
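The kernel change computes `input_sizes = {num_channels, input.size(input.dim() - 1)}` once, before the `#ifdef`, so both the ATen and the portable builds view the input as the same 2-D (tokens x channels) tensor before dequantizing; it also fixes the old ATen-only path, which built the shape from `input.dim() - 1` rather than the size of the last dimension. In Python terms the equivalent flattening is roughly (an illustrative sketch, not the kernel):

    import math
    import torch

    def view_as_tokens_by_channels(x: torch.Tensor) -> torch.Tensor:
        # Collapse every leading dim into one "token" axis so that per-token
        # scales and zero-points index rows of a 2-D view.
        num_tokens = math.prod(x.shape[:-1])
        return x.reshape(num_tokens, x.shape[-1])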

kernels/quantized/test/test_quant_dequant_per_token.py

Lines changed: 0 additions & 4 deletions
@@ -9,14 +9,10 @@
 import unittest
 
 import torch
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 
 
 class QuantizePerTokenTest(unittest.TestCase):
 
-    def setUp(self):
-        pass
-
     def test_quantize_per_token(self):
         input_tensor = torch.tensor(
             [[-0.5, 0.3, 1.2], [0.1, -0.8, 2.1], [-5, 1, 2]], dtype=torch.float32
