Commit a943088
Update on "[Executorch] Add quantized kv cache to oss ci"
Fixes to make sure quantized kv cache works in oss Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/) [ghstack-poisoned]
2 parents 128e461 + ba7d02e commit a943088

3 files changed: +13 -15 lines


examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 13 additions & 1 deletion
@@ -4,17 +4,29 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from pathlib import Path
 import logging
 from enum import Enum

-import executorch.extension.llm.custom_ops  # noqa: F401
+from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

 import torch
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


+try:
+    op = torch.ops.quantized_decomposed.quantize_per_token
+    assert op is not None
+except:
+    libs = list(Path(__file__).parent.resolve().glob("libquantized_ops_aot_lib.*"))
+    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+    op = torch.ops.quantized_decomposed.quantize_per_token
+    assert op is not None
+
 """
 Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
 """

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 0 additions & 13 deletions
@@ -192,19 +192,6 @@ def embedding_byte_dtype_out_meta(
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )

-# TODO: move these registrations to pytorch core
-quantized_decomposed_lib.define(
-    "quantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)",
-)
-
-quantized_decomposed_lib.define(
-    "dequantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype, *, Tensor(a!) out) -> Tensor(a!)",
-)
-
-quantized_decomposed_lib.define(
-    "choose_qparams_per_token_asymmetric.out(Tensor input, ScalarType dtype, *, Tensor(a!) scale_out, Tensor(b!) zero_point_out) -> (Tensor(a!), Tensor(b!))",
-)
-

 @impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
 def embedding_2bit(
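
The three .out schemas removed here carried a TODO to move them to PyTorch core; once they are defined elsewhere, keeping a second define in exir would break, since torch.library rejects defining the same operator twice. A minimal sketch of that behavior, using a throwaway demo_ns namespace that is purely illustrative and not part of this commit:

# Hypothetical illustration (not from this commit): a second define() of the
# same operator on one library raises, which is why the duplicated .out
# registrations above had to go once the schemas lived upstream.
from torch.library import Library

lib = Library("demo_ns", "DEF")  # throwaway namespace for this sketch
schema = "quantize_per_token.out(Tensor input, Tensor scales, *, Tensor(a!) out) -> Tensor(a!)"
lib.define(schema)
try:
    lib.define(schema)  # duplicate registration of the same schema
except RuntimeError as err:
    print(f"duplicate registration rejected: {err}")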

extension/llm/custom_ops/custom_ops.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@

 from torch.library import impl

-# TODO rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
