Skip to content

Commit e49b3ad

Browse files
committed
Update on "[Executorch] Add quantized kv cache to oss ci"
Fixes to make sure quantized kv cache works in oss Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/) [ghstack-poisoned]
2 parents 082b308 + e6d66f2 commit e49b3ad

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,24 @@
1818
op = torch.ops.quantized_decomposed.quantize_per_token.out
1919
assert op is not None
2020
except:
21-
import executorch.kernels.quantized # noqa: F401
21+
import glob
22+
23+
import executorch
24+
25+
from executorch.extension.pybindings import portable_lib # noqa # usort: skip
26+
27+
# Ideally package is installed in only one location but usage of
28+
# PYTHONPATH can result in multiple locations.
29+
# ATM this is mainly used in CI for qnn runner. Will need to revisit this
30+
executorch_package_path = executorch.__path__[-1]
31+
libs = list(
32+
glob.glob(
33+
f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
34+
)
35+
)
36+
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
37+
logging.info(f"Loading custom ops library: {libs[0]}")
38+
torch.ops.load_library(libs[0])
2239
op = torch.ops.quantized_decomposed.quantize_per_token.out
2340
assert op is not None
2441

@@ -230,8 +247,8 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
230247

231248

232249
def replace_kv_cache_with_quantized_kv_cache(module):
233-
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
234-
250+
# This is needed to ensure that custom ops are registered
251+
from executorch.extension.pybindings import portable_lib # noqa # usort: skip
235252
logging.warning(
236253
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
237254
)

kernels/quantized/__init__.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import logging
8-
97
try:
10-
import glob
8+
from pathlib import Path
119

10+
libs = list(Path(__file__).parent.resolve().glob("**/libquantized_ops_aot_lib.*"))
11+
del Path
12+
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
1213
import torch as _torch
13-
import executorch
1414

15-
# Ideally package is installed in only one location but usage of
16-
# PYTHONPATH can result in multiple locations.
17-
# ATM this is mainly used in CI for qnn runner. Will need to revisit this
18-
executorch_package_path = executorch.__path__[-1]
19-
libs = list(
20-
glob.glob(
21-
f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
22-
)
23-
)
24-
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
25-
logging.info(f"Loading custom ops library: {libs[0]}")
2615
_torch.ops.load_library(libs[0])
2716
del _torch
2817
except:

0 commit comments

Comments
 (0)