Commit 7a5589f

[Executorch] Add quantized kv cache to oss ci
Pull Request resolved: #6997

Fixes to make sure quantized kv cache works in OSS.

ghstack-source-id: 254902439
@exported-using-ghexport

Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/)
1 parent 0285d2b commit 7a5589f

6 files changed: +26, -4 lines


.ci/scripts/test_llama.sh

Lines changed: 9 additions & 0 deletions
@@ -100,6 +100,12 @@ else
   COREML=OFF
 fi
 
+if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
+  QUANTIZE_KV_CACHE=ON
+else
+  QUANTIZE_KV_CACHE=OFF
+fi
+
 echo "COREML option ${COREML}"
 
 if [[ "${MODE}" =~ .*qnn.* ]]; then
@@ -235,6 +241,9 @@ fi
 if [[ "${QNN}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} -kv -v --qnn --disable_dynamic_shape"
 fi
+if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then
+  EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
+fi
 # Add dynamically linked library location
 $PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS}

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ jobs:
     strategy:
       matrix:
        dtype: [fp32]
-        mode: [portable, xnnpack+custom, xnnpack+custom+qe]
+        mode: [portable, xnnpack+custom, xnnpack+custom+qe,xnnpack+custom+quantize_kv,xnnpack+quantize_kv]
        include:
          - dtype: bf16
            mode: portable

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ jobs:
     strategy:
       matrix:
        dtype: [fp32]
-        mode: [portable, xnnpack+kv+custom, mps, coreml]
+        mode: [portable, xnnpack+kv+custom, mps, coreml, xnnpack+custom+quantize_kv]
        include:
          - dtype: bf16
            mode: portable

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 14 additions & 0 deletions
@@ -4,15 +4,29 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from pathlib import Path
 import logging
 from enum import Enum
 
+from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
 import torch
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
 
+try:
+    op = torch.ops.quantized_decomposed.quantize_per_token
+    assert op is not None
+except:
+    libs = list(Path(__file__).parent.resolve().glob("libquantized_ops_aot_lib.*"))
+    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+    op = torch.ops.quantized_decomposed.quantize_per_token
+    assert op is not None
+
 """
 Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
 """

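For context, the try/except added above only makes sure torch.ops.quantized_decomposed.quantize_per_token is registered, either through the imports at the top of the file or by loading a prebuilt libquantized_ops_aot_lib placed next to the module. The sketch below is not part of this commit; it illustrates the per-token int8 quantize/dequantize round trip such a cache builds on, with the tensor shape and the int8 range chosen purely for illustration.

```python
import torch

# Registers the quantized_decomposed ops used below.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

# Pretend this is a (batch, heads, seq, head_dim) slice being written into the k cache.
k_val = torch.randn(1, 8, 1, 64, dtype=torch.float32)

# One (scale, zero_point) pair per token, chosen asymmetrically over the last dim.
scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
    k_val, torch.int8
)

qmin, qmax = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max
k_quant = torch.ops.quantized_decomposed.quantize_per_token(
    k_val, scales, zero_points, qmin, qmax, torch.int8
)
k_dequant = torch.ops.quantized_decomposed.dequantize_per_token(
    k_quant, scales, zero_points, qmin, qmax, torch.int8, torch.float32
)

# The cache keeps k_quant plus scales/zero_points; attention consumes k_dequant.
print((k_val - k_dequant).abs().max())
```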
examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def forward(
 
         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
-        if isinstance(self.kv_cache, QuantizedKVCache):
+        if hasattr(self.kv_cache, "quantized_cache_dtype"):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next

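The sdpa.py change swaps an isinstance check against QuantizedKVCache for a duck-typed attribute check, so this forward path no longer depends on importing the QuantizedKVCache class itself. Below is a minimal sketch of that pattern using hypothetical stand-in classes, not code from the repo.

```python
import torch


class PlainKVCacheLike:
    """Hypothetical stand-in for a float KV cache."""

    def __init__(self):
        self.k_cache = torch.zeros(1, 8, 16, 64)
        self.v_cache = torch.zeros(1, 8, 16, 64)


class QuantizedKVCacheLike(PlainKVCacheLike):
    """Hypothetical stand-in for a quantized KV cache."""

    def __init__(self):
        super().__init__()
        self.quantized_cache_dtype = torch.int8  # marker attribute the check looks for


def needs_dequantize(kv_cache) -> bool:
    # Attribute check instead of isinstance: any cache exposing
    # quantized_cache_dtype is treated as quantized, without importing its class here.
    return hasattr(kv_cache, "quantized_cache_dtype")


print(needs_dequantize(PlainKVCacheLike()))      # False
print(needs_dequantize(QuantizedKVCacheLike()))  # True
```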
extension/llm/custom_ops/custom_ops.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 
 from torch.library import impl
 
-# TODO rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
