|
14 | 14 | from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 |
15 | 15 |
|
16 | 16 |
|
17 | | -try: |
18 | | - op = torch.ops.quantized_decomposed.quantize_per_token.out |
19 | | - assert op is not None |
20 | | -except: |
21 | | - import glob |
22 | | - |
23 | | - import executorch |
24 | | - |
25 | | - from executorch.extension.pybindings import portable_lib # noqa # usort: skip |
26 | | - |
27 | | - # Ideally package is installed in only one location but usage of |
28 | | - # PYTHONPATH can result in multiple locations. |
29 | | - # ATM this is mainly used in CI for qnn runner. Will need to revisit this |
30 | | - executorch_package_path = executorch.__path__[-1] |
31 | | - libs = list( |
32 | | - glob.glob( |
33 | | - f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True |
34 | | - ) |
35 | | - ) |
36 | | - assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" |
37 | | - logging.info(f"Loading custom ops library: {libs[0]}") |
38 | | - torch.ops.load_library(libs[0]) |
39 | | - op = torch.ops.quantized_decomposed.quantize_per_token.out |
40 | | - assert op is not None |
41 | 17 |
|
42 | 18 | """ |
43 | 19 | Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py |
@@ -247,6 +223,28 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType): |
247 | 223 |
|
248 | 224 |
|
249 | 225 | def replace_kv_cache_with_quantized_kv_cache(module): |
| 226 | + try: |
| 227 | + op = torch.ops.quantized_decomposed.quantize_per_token.out |
| 228 | + assert op is not None |
| 229 | + except: |
| 230 | + import glob |
| 231 | + import executorch |
| 232 | + from executorch.extension.pybindings import portable_lib # noqa # usort: skip |
| 233 | + |
| 234 | + # Ideally package is installed in only one location but usage of |
| 235 | + # PYTHONPATH can result in multiple locations. |
| 236 | + # ATM this is mainly used in CI for qnn runner. Will need to revisit this |
| 237 | + executorch_package_path = executorch.__path__[-1] |
| 238 | + libs = list( |
| 239 | + glob.glob( |
| 240 | + f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True |
| 241 | + ) |
| 242 | + ) |
| 243 | + assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" |
| 244 | + logging.info(f"Loading custom ops library: {libs[0]}") |
| 245 | + torch.ops.load_library(libs[0]) |
| 246 | + op = torch.ops.quantized_decomposed.quantize_per_token.out |
| 247 | + assert op is not None |
250 | 248 | # This is needed to ensure that custom ops are registered |
251 | 249 | from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 |
252 | 250 |
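
For reference, the lazy-registration pattern moved into `replace_kv_cache_with_quantized_kv_cache` above, pulled out as a standalone sketch. It reuses the `libquantized_ops_aot_lib.*` file name and the `quantize_per_token.out` op from the diff; the helper name and the specific exception types caught are assumptions for illustration, not taken from the PR.

```python
# Illustrative sketch only: helper name and caught exception types are assumptions.
import glob
import logging

import torch


def _ensure_quantize_per_token_registered() -> None:
    """Load the quantized ops AOT library only if the op is not yet registered."""
    try:
        # Attribute access fails if the op was never registered.
        if torch.ops.quantized_decomposed.quantize_per_token.out is not None:
            return
    except (AttributeError, RuntimeError):
        pass

    import executorch  # resolve the installed package location lazily

    package_path = executorch.__path__[-1]
    libs = glob.glob(
        f"{package_path}/**/libquantized_ops_aot_lib.*", recursive=True
    )
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    logging.info("Loading custom ops library: %s", libs[0])
    torch.ops.load_library(libs[0])
```

Because the check now lives inside `replace_kv_cache_with_quantized_kv_cache` instead of at module import time, importing the module no longer triggers the library search; the cost is paid only by callers that actually swap in the quantized KV cache.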
|
|