
Commit e605bf2

Update on "[Executorch] Add quantized kv cache to oss ci"
Fixes to make sure quantized KV cache works in OSS.

Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/)

[ghstack-poisoned]
2 parents: b39b8a7 + 032da3e

1 file changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 22 additions & 24 deletions
@@ -14,30 +14,6 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
 
-try:
-    op = torch.ops.quantized_decomposed.quantize_per_token.out
-    assert op is not None
-except:
-    import glob
-
-    import executorch
-
-    from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
-
-    # Ideally package is installed in only one location but usage of
-    # PYATHONPATH can result in multiple locations.
-    # ATM this is mainly used in CI for qnn runner. Will need to revisit this
-    executorch_package_path = executorch.__path__[-1]
-    libs = list(
-        glob.glob(
-            f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
-        )
-    )
-    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
-    logging.info(f"Loading custom ops library: {libs[0]}")
-    torch.ops.load_library(libs[0])
-    op = torch.ops.quantized_decomposed.quantize_per_token.out
-    assert op is not None
 
 """
 Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py

@@ -247,6 +223,28 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
 
 
 def replace_kv_cache_with_quantized_kv_cache(module):
+    try:
+        op = torch.ops.quantized_decomposed.quantize_per_token.out
+        assert op is not None
+    except:
+        import glob
+        import executorch
+        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
+
+        # Ideally package is installed in only one location but usage of
+        # PYATHONPATH can result in multiple locations.
+        # ATM this is mainly used in CI for qnn runner. Will need to revisit this
+        executorch_package_path = executorch.__path__[-1]
+        libs = list(
+            glob.glob(
+                f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
+            )
+        )
+        assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+        logging.info(f"Loading custom ops library: {libs[0]}")
+        torch.ops.load_library(libs[0])
+        op = torch.ops.quantized_decomposed.quantize_per_token.out
+        assert op is not None
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
 
