Skip to content

Commit 73c277a

Browse files
committed
Update on "[Executorch] Add quantized kv cache to oss ci"
Fixes to make sure the quantized KV cache works in OSS. Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/) [ghstack-poisoned]
2 parents a943088 + ffac7af commit 73c277a

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,34 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from pathlib import Path
87
import logging
98
from enum import Enum
109

11-
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
12-
1310
import torch
1411
import torch.nn as nn
1512
from executorch.examples.models.llama.llama_transformer import KVCache
13+
14+
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
1615
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
1716

1817

1918
try:
20-
op = torch.ops.quantized_decomposed.quantize_per_token
19+
op = torch.ops.quantized_decomposed.quantize_per_token.out
2120
assert op is not None
2221
except:
23-
libs = list(Path(__file__).parent.resolve().glob("libquantized_ops_aot_lib.*"))
22+
import executorch
23+
import glob
24+
25+
executorch_package_path = executorch.__path__[0]
26+
libs = list(
27+
glob.glob(
28+
f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
29+
)
30+
)
2431
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
2532
logging.info(f"Loading custom ops library: {libs[0]}")
2633
torch.ops.load_library(libs[0])
27-
op = torch.ops.quantized_decomposed.quantize_per_token
34+
op = torch.ops.quantized_decomposed.quantize_per_token.out
2835
assert op is not None
2936

3037
"""
@@ -204,7 +211,6 @@ def update(self, input_pos, k_val, v_val):
204211
seq_length = k_val.size(dim_to_slice)
205212
narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
206213
narrowed_k.copy_(k_val)
207-
# pyre-ignore: Incompatible parameter type [6]
208214
narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
209215
narrowed_v.copy_(v_val)
210216
else:

kernels/quantized/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,17 @@ if(NOT CMAKE_GENERATOR STREQUAL "Xcode"
6060
set(_quantized_aot_ops
6161
"quantized_decomposed::add.out"
6262
"quantized_decomposed::choose_qparams.Tensor_out"
63+
"quantized_decomposed::choose_qparams_per_token_asymmetric.out"
6364
"quantized_decomposed::dequantize_per_channel.out"
6465
"quantized_decomposed::dequantize_per_tensor.out"
6566
"quantized_decomposed::dequantize_per_tensor.Tensor_out"
67+
"quantized_decomposed::dequantize_per_token.out"
6668
"quantized_decomposed::mixed_linear.out"
6769
"quantized_decomposed::mixed_mm.out"
6870
"quantized_decomposed::quantize_per_channel.out"
6971
"quantized_decomposed::quantize_per_tensor.out"
7072
"quantized_decomposed::quantize_per_tensor.Tensor_out"
73+
"quantized_decomposed::quantize_per_token.out"
7174
)
7275
gen_selected_ops(
7376
LIB_NAME "quantized_ops_aot_lib" ROOT_OPS ${_quantized_aot_ops}

0 commit comments

Comments
 (0)