9 changes: 9 additions & 0 deletions .ci/scripts/test_llama.sh
@@ -110,6 +110,12 @@ else
COREML=OFF
fi

if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
QUANTIZE_KV_CACHE=ON
else
QUANTIZE_KV_CACHE=OFF
fi

echo "COREML option ${COREML}"

if [[ "${MODE}" =~ .*qnn.* ]]; then
@@ -249,6 +255,9 @@ if [[ "${QNN}" == "ON" ]]; then
EXPORT_ARGS+=" --tokenizer_path tokenizer.model --pt2e_quantize qnn_16a16w --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --calibration_data Once "
fi
fi
if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then
EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
fi
# Add dynamically linked library location
$PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS}

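The two hunks above wire the new CI mode into the export flags: any MODE containing quantize_kv turns on QUANTIZE_KV_CACHE, which in turn appends --quantize_kv_cache to EXPORT_ARGS. A Python re-statement of that gating, for illustration only (build_export_args is a hypothetical helper, not part of the script):

def build_export_args(mode: str, base_args: str = "") -> str:
    # Mirrors: if [[ "${MODE}" =~ .*quantize_kv.* ]]; then QUANTIZE_KV_CACHE=ON
    export_args = base_args
    if "quantize_kv" in mode:
        # Mirrors: EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
        export_args += " --quantize_kv_cache"
    return export_args

assert "--quantize_kv_cache" in build_export_args("xnnpack+custom+quantize_kv")
assert "--quantize_kv_cache" not in build_export_args("portable")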
2 changes: 1 addition & 1 deletion .github/workflows/pull.yml
@@ -86,7 +86,7 @@ jobs:
    strategy:
      matrix:
        dtype: [fp32]
-        mode: [portable, xnnpack+custom, xnnpack+custom+qe]
+        mode: [portable, xnnpack+custom, xnnpack+custom+qe, xnnpack+custom+quantize_kv, xnnpack+quantize_kv]
      include:
        - dtype: bf16
          mode: portable
2 changes: 1 addition & 1 deletion .github/workflows/trunk.yml
@@ -225,7 +225,7 @@ jobs:
    strategy:
      matrix:
        dtype: [fp32]
-        mode: [portable, xnnpack+kv+custom, mps, coreml]
+        mode: [portable, xnnpack+kv+custom, mps, coreml, xnnpack+custom+quantize_kv]
      include:
        - dtype: bf16
          mode: portable
examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import KVCache
+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401


@@ -221,6 +222,33 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):


def replace_kv_cache_with_quantized_kv_cache(module):
+    try:
+        op = torch.ops.quantized_decomposed.quantize_per_token.out
+        assert op is not None
+    except Exception:
+        import glob
+
+        import executorch
+        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
+
+        # Ideally the package is installed in only one location, but usage of
+        # PYTHONPATH can result in multiple locations.
+        # At the moment this is mainly used in CI for the qnn runner. Will need to revisit this.
+        executorch_package_path = executorch.__path__[-1]
+        libs = list(
+            glob.glob(
+                f"{executorch_package_path}/**/libquantized_ops_aot_lib.*",
+                recursive=True,
+            )
+        )
+        assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+        logging.info(f"Loading custom ops library: {libs[0]}")
+        torch.ops.load_library(libs[0])
+        op = torch.ops.quantized_decomposed.quantize_per_token.out
+        assert op is not None
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
    logging.warning(
        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
    )
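The addition above makes replace_kv_cache_with_quantized_kv_cache self-healing: if the quantized_decomposed ops are not registered yet, it searches the installed executorch package for libquantized_ops_aot_lib and loads it with torch.ops.load_library. A standalone sketch of the same pattern, under the assumption that exactly one copy of the library is installed (ensure_aot_ops_loaded is a hypothetical name):

import glob
import logging

import torch


def ensure_aot_ops_loaded(namespace: str, op_name: str, lib_pattern: str) -> None:
    # Try to resolve the op; return silently if it is already registered.
    try:
        getattr(getattr(torch.ops, namespace), op_name)
        return
    except AttributeError:
        pass
    import executorch

    # Use the last __path__ entry, matching the diff's choice when PYTHONPATH
    # makes the package resolve to multiple locations.
    package_path = executorch.__path__[-1]
    libs = glob.glob(f"{package_path}/**/{lib_pattern}", recursive=True)
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    logging.info("Loading custom ops library: %s", libs[0])
    torch.ops.load_library(libs[0])


ensure_aot_ops_loaded("quantized_decomposed", "quantize_per_token", "libquantized_ops_aot_lib.*")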
2 changes: 1 addition & 1 deletion examples/models/llama/source_transformation/sdpa.py
@@ -56,7 +56,7 @@ def forward(

        k_cache = self.kv_cache.k_cache
        v_cache = self.kv_cache.v_cache
-        if isinstance(self.kv_cache, QuantizedKVCache):
+        if hasattr(self.kv_cache, "quantized_cache_dtype"):
            # Update quantized cache, scales and zero points;
            # return the dequantized kv cache.
            # Not the most optimal; optimizations to follow.
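The one-line change above replaces an isinstance check with attribute-based detection. The diff gives no rationale; a plausible reading (my assumption) is that hasattr keeps the code path independent of the QuantizedKVCache class object, which can differ when the class is imported under another module path. A tiny sketch of the duck-typing check, with stand-in classes:

class PlainKVCache:
    pass


class FakeQuantizedKVCache:
    quantized_cache_dtype = "int8"  # placeholder value for illustration


def uses_quantized_cache(kv_cache) -> bool:
    # Same test as the new sdpa.py line: detect by attribute, not by type.
    return hasattr(kv_cache, "quantized_cache_dtype")


assert not uses_quantized_cache(PlainKVCache())
assert uses_quantized_cache(FakeQuantizedKVCache())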
8 changes: 7 additions & 1 deletion extension/llm/custom_ops/custom_ops.py
@@ -26,7 +26,13 @@

import executorch

-executorch_package_path = executorch.__path__[0]
+# This is needed to ensure that custom ops are registered
+from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
+
+# Ideally the package is installed in only one location, but usage of
+# PYTHONPATH can result in multiple locations.
+# At the moment this is mainly used in CI for the qnn runner. Will need to revisit this.
+executorch_package_path = executorch.__path__[-1]
logging.info(f"Looking for libcustom_ops_aot_lib.so in {executorch_package_path}")
libs = list(
    glob.glob(
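To see what the PYTHONPATH comment above is guarding against, one can inspect the package's __path__ directly; with a source checkout on PYTHONPATH in addition to an installed wheel it may list more than one directory (output is environment-dependent):

import executorch

# The diff switches from __path__[0] to __path__[-1], i.e. the last entry.
for index, path in enumerate(executorch.__path__):
    print(index, path)
print("using:", executorch.__path__[-1])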
3 changes: 3 additions & 0 deletions kernels/quantized/CMakeLists.txt
@@ -60,14 +60,17 @@ if(NOT CMAKE_GENERATOR STREQUAL "Xcode"
  set(_quantized_aot_ops
      "quantized_decomposed::add.out"
      "quantized_decomposed::choose_qparams.Tensor_out"
+      "quantized_decomposed::choose_qparams_per_token_asymmetric.out"
      "quantized_decomposed::dequantize_per_channel.out"
      "quantized_decomposed::dequantize_per_tensor.out"
      "quantized_decomposed::dequantize_per_tensor.Tensor_out"
+      "quantized_decomposed::dequantize_per_token.out"
      "quantized_decomposed::mixed_linear.out"
      "quantized_decomposed::mixed_mm.out"
      "quantized_decomposed::quantize_per_channel.out"
      "quantized_decomposed::quantize_per_tensor.out"
      "quantized_decomposed::quantize_per_tensor.Tensor_out"
+      "quantized_decomposed::quantize_per_token.out"
  )
  gen_selected_ops(
      LIB_NAME "quantized_ops_aot_lib" ROOT_OPS ${_quantized_aot_ops}
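With the three per-token entries added to the selective-build list, the corresponding out-variant kernels are compiled into quantized_ops_aot_lib. A hedged sketch of exercising the functional forms from Python; the signatures follow torch.ao.quantization.fx._decomposed (already imported in the Python hunk above), and the shape and int8 range are illustrative assumptions:

import torch

# Importing this module registers the quantized_decomposed op schemas.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

x = torch.randn(1, 4, 8)  # (batch, tokens, dim)

# One (scale, zero_point) pair per token, asymmetric int8.
scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
    x, torch.int8
)
q = torch.ops.quantized_decomposed.quantize_per_token(
    x, scales, zero_points, -128, 127, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_token(
    q, scales, zero_points, -128, 127, torch.int8, x.dtype
)
print((x - dq).abs().max())  # quantization error should be small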