Skip to content

Commit 6f1efc5

Browse files
committed
Update on "[Executorch] Add quantized kv cache to oss ci"
Fixes to make sure quantized kv cache works in oss Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/) [ghstack-poisoned]
2 parents 643086c + 9cd62d5 commit 6f1efc5

File tree

2 files changed

+18
-23
lines changed

2 files changed

+18
-23
lines changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,7 @@
1818
op = torch.ops.quantized_decomposed.quantize_per_token.out
1919
assert op is not None
2020
except:
21-
import glob
22-
23-
import executorch
24-
25-
# Ideally package is installed in only one location but usage of
26-
# PYATHONPATH can result in multiple locations.
27-
# ATM this is mainly used in CI for qnn runner. Will need to revisit this
28-
executorch_package_path = executorch.__path__[-1]
29-
libs = list(
30-
glob.glob(
31-
f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
32-
)
33-
)
34-
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
35-
logging.info(f"Loading custom ops library: {libs[0]}")
36-
torch.ops.load_library(libs[0])
21+
import executorch.kernels.quantized # noqa: F401
3722
op = torch.ops.quantized_decomposed.quantize_per_token.out
3823
assert op is not None
3924

kernels/quantized/__init__.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,26 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
8+
79
try:
8-
from pathlib import Path
10+
import glob
911

10-
libs = list(Path(__file__).parent.resolve().glob("**/libquantized_ops_aot_lib.*"))
11-
del Path
12-
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
13-
import torch as _torch
12+
import torch
13+
import executorch
1414

15-
_torch.ops.load_library(libs[0])
16-
del _torch
15+
# Ideally package is installed in only one location but usage of
16+
# PYATHONPATH can result in multiple locations.
17+
# ATM this is mainly used in CI for qnn runner. Will need to revisit this
18+
executorch_package_path = executorch.__path__[-1]
19+
libs = list(
20+
glob.glob(
21+
f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
22+
)
23+
)
24+
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
25+
logging.info(f"Loading custom ops library: {libs[0]}")
26+
torch.ops.load_library(libs[0])
1727
except:
1828
import logging
1929

0 commit comments

Comments
 (0)