
Commit d28f61a

[Executorch] Add quantized kv cache to oss ci
Pull Request resolved: #6997

Fixes to make sure the quantized KV cache works in OSS.

ghstack-source-id: 254774011
@exported-using-ghexport

Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/)
1 parent d4874e8 commit d28f61a

6 files changed (+27, -3 lines)
.ci/scripts/test_llama.sh

Lines changed: 9 additions & 0 deletions
@@ -70,6 +70,12 @@ else
   COREML=OFF
 fi

+if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
+  QUANTIZE_KV_CACHE=ON
+else
+  QUANTIZE_KV_CACHE=OFF
+fi
+
 echo "COREML option ${COREML}"

 if [[ "${MODE}" =~ .*qnn.* ]]; then
@@ -205,6 +211,9 @@ fi
 if [[ "${QNN}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} -kv -v --qnn --disable_dynamic_shape"
 fi
+if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then
+  EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
+fi
 # Add dynamically linked library location
 $PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS}
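
In plain terms, the new shell logic turns the QUANTIZE_KV_CACHE switch on for any CI mode string containing "quantize_kv", and that switch in turn appends --quantize_kv_cache to the arguments passed to examples.models.llama.export_llama. Below is a minimal Python rendering of that logic, purely for illustration; the function name and the base arguments are placeholders, not part of the script.

import re

def build_export_args(mode: str, base_args: str = "-kv") -> str:
    # Mirrors `[[ "${MODE}" =~ .*quantize_kv.* ]]` in .ci/scripts/test_llama.sh.
    quantize_kv_cache = "ON" if re.search(r"quantize_kv", mode) else "OFF"
    export_args = base_args
    if quantize_kv_cache == "ON":
        # Same flag the script now appends before invoking export_llama.
        export_args += " --quantize_kv_cache"
    return export_args

# The new CI matrix entry "xnnpack+custom+quantize_kv" would therefore export with:
print(build_export_args("xnnpack+custom+quantize_kv"))  # "-kv --quantize_kv_cache"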

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        mode: [portable, xnnpack+custom, xnnpack+custom+qe]
+        mode: [portable, xnnpack+custom, xnnpack+custom+qe,xnnpack+custom+quantize_kv,xnnpack+quantize_kv]
         include:
           - dtype: bf16
             mode: portable
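
The two added matrix entries create CI jobs that exercise the quantized KV cache path, with and without the custom-ops variant. A quick sanity check of which modes flip the new switch in test_llama.sh (illustrative only):

modes = [
    "portable",
    "xnnpack+custom",
    "xnnpack+custom+qe",
    "xnnpack+custom+quantize_kv",
    "xnnpack+quantize_kv",
]

# Mirrors the `[[ "${MODE}" =~ .*quantize_kv.* ]]` check added above.
for mode in modes:
    print(f"{mode:30s} QUANTIZE_KV_CACHE={'ON' if 'quantize_kv' in mode else 'OFF'}")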

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        mode: [portable, xnnpack+kv+custom, mps, coreml]
+        mode: [portable, xnnpack+kv+custom, mps, coreml, xnnpack+custom+quantize_kv]
         include:
           - dtype: bf16
             mode: portable

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@
 import logging
 from enum import Enum

+import executorch.extension.llm.custom_ops  # noqa: F401
+
 import torch
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
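
The added import is there purely for its side effect: importing executorch.extension.llm.custom_ops registers the custom kernels with PyTorch's operator registry, so the quantized KV cache can resolve them in an OSS build where nothing else has loaded them yet (hence the noqa on the seemingly unused import). Here is a self-contained sketch of that "import for registration" pattern, using a made-up namespace rather than the real ExecuTorch ops:

import torch
from torch.library import Library, impl

# Defining and implementing an op at import time makes torch.ops.<ns>.<op>
# resolvable for any later caller; the namespace and op name here are invented.
demo_lib = Library("demo_ops", "DEF")
demo_lib.define("update_cache(Tensor value, Tensor cache) -> Tensor")

@impl(demo_lib, "update_cache", "CompositeExplicitAutograd")
def update_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # Trivial stand-in body; the real custom ops update the cache in place.
    return cache + value

# After the registration above (i.e., after the import), this call resolves:
out = torch.ops.demo_ops.update_cache(torch.ones(4), torch.zeros(4))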

examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def forward(

         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
-        if isinstance(self.kv_cache, QuantizedKVCache):
+        if hasattr(self.kv_cache, "quantized_cache_dtype"):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
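
Switching from isinstance to hasattr makes the SDPA transform duck-typed: any cache object that exposes a quantized_cache_dtype attribute is treated as a quantized cache, without requiring the QuantizedKVCache class itself to be importable (or to be the exact same class object) in the environment doing the transform. A small sketch of the check, with an invented stand-in class:

import torch

class FakeQuantizedKVCache(torch.nn.Module):
    # Stand-in, not the real ExecuTorch class; only the attribute matters.
    def __init__(self):
        super().__init__()
        self.quantized_cache_dtype = torch.int8

def uses_quantized_cache(kv_cache) -> bool:
    # Equivalent in spirit to the replaced isinstance() check, but robust to
    # import-path and class-identity differences across builds.
    return hasattr(kv_cache, "quantized_cache_dtype")

assert uses_quantized_cache(FakeQuantizedKVCache())
assert not uses_quantized_cache(torch.nn.Identity())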

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 13 additions & 0 deletions
@@ -192,6 +192,19 @@ def embedding_byte_dtype_out_meta(
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )

+# TODO: move these registrations to pytorch core
+quantized_decomposed_lib.define(
+    "quantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "dequantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "choose_qparams_per_token_asymmetric.out(Tensor input, ScalarType dtype, *, Tensor(a!) scale_out, Tensor(b!) zero_point_out) -> (Tensor(a!), Tensor(b!))",
+)
+

 @impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
 def embedding_2bit(
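
These define calls add out-variant schemas for the per-token quantization ops, which memory-planned ExecuTorch graphs need when the quantized KV cache is enabled. As a sketch of how such a pairing can work, here is the same define/impl pattern on an invented namespace, with the out variant simply writing the functional result into the caller's buffer; this is illustrative only, and the real out-variant kernels may be registered elsewhere (e.g., in C++).

import torch
from torch.library import Library, impl

demo_lib = Library("demo_quant", "DEF")

# Functional variant plus its out variant, mirroring quantize_per_token(.out).
demo_lib.define(
    "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)
demo_lib.define(
    "quantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

@impl(demo_lib, "quantize_per_token", "CompositeExplicitAutograd")
def quantize_per_token(input, scales, zero_points, quant_min, quant_max, dtype):
    # Simplified per-token affine quantization: one (scale, zero_point) per token row.
    q = torch.round(input / scales + zero_points)
    return torch.clamp(q, quant_min, quant_max).to(dtype)

@impl(demo_lib, "quantize_per_token.out", "CompositeExplicitAutograd")
def quantize_per_token_out(input, scales, zero_points, quant_min, quant_max, dtype, out=None):
    # The out variant fills the caller-provided buffer, which is what the
    # memory planner expects from .out ops.
    return out.copy_(
        torch.ops.demo_quant.quantize_per_token(
            input, scales, zero_points, quant_min, quant_max, dtype
        )
    )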
