Commit 5d29a7d

Add torchao conversion (#14545)
This adds a new "torchao" backend for pre-quantized checkpoints. Pre-quantized checkpoints can already be lowered to a backend (e.g., XNNPACK) by specifying "-X" in etLLM. With this PR, they can also be lowered to torchao lowbit kernels by passing "--use-torchao-kernels" to the export script instead of "-X". Note that this lowers both linear and tied-embedding ops to torchao kernels. To run linear with XNNPACK but tied embedding with torchao, use "--use-torchao-kernels-tied-embedding" together with "-X". New CI tests cover the flow.
1 parent f662cf5 commit 5d29a7d
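
For orientation, a minimal sketch of the two flows described above. The flags come from this commit; the `export_llama` entry point and checkpoint path are assumptions for illustration:

```bash
# Flow 1: lower both quantized linear and tied-embedding ops to
# torchao lowbit kernels (checkpoint path is a placeholder).
python -m examples.models.llama.export_llama \
    --checkpoint pre_quantized_checkpoint.pt \
    --use-torchao-kernels

# Flow 2: lower linear ops to XNNPACK, tied embedding to torchao.
python -m examples.models.llama.export_llama \
    --checkpoint pre_quantized_checkpoint.pt \
    -X --use-torchao-kernels-tied-embedding
```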

5 files changed (+105, -13 lines)

.ci/scripts/test_torchao_huggingface_checkpoints.sh

Lines changed: 24 additions & 7 deletions
```diff
@@ -5,6 +5,7 @@ set -euxo pipefail
 # Args / flags
 # -------------------------
 TEST_WITH_RUNNER=0
+USE_TORCHAO_KERNELS=0
 MODEL_NAME=""
 
 # Parse args
@@ -22,10 +23,14 @@ while [[ $# -gt 0 ]]; do
     --test_with_runner)
       TEST_WITH_RUNNER=1
       ;;
+    --use_torchao_kernels)
+      USE_TORCHAO_KERNELS=1
+      ;;
     -h|--help)
-      echo "Usage: $0 <model_name> [--test_with_runner]"
+      echo "Usage: $0 <model_name> [--test_with_runner] [--use_torchao_kernels]"
       echo "  model_name: qwen3_4b | phi_4_mini"
       echo "  --test_with_runner: build ET + run llama_main to sanity-check the export"
+      echo "  --use_torchao_kernels: use torchao kernels for linear and tied embedding"
       exit 0
       ;;
     *)
@@ -42,6 +47,13 @@ fi
 
 MODEL_OUT=model.pte
 
+
+# Default to XNNPACK
+BACKEND_ARGS="-X --xnnpack-extended-ops"
+if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
+  BACKEND_ARGS="--use-torchao-kernels"
+fi
+
 case "$MODEL_NAME" in
   qwen3_4b)
     echo "Running Qwen3-4B export..."
@@ -58,12 +70,12 @@ case "$MODEL_NAME" in
       --output_name $MODEL_OUT \
       -kv \
       --use_sdpa_with_kv_cache \
-      -X \
-      --xnnpack-extended-ops \
       --max_context_length 1024 \
       --max_seq_length 1024 \
+      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
+      --verbose \
       --dtype fp32 \
-      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
+      ${BACKEND_ARGS}
     ;;
 
   phi_4_mini)
@@ -81,12 +93,12 @@ case "$MODEL_NAME" in
       --output_name $MODEL_OUT \
       -kv \
       --use_sdpa_with_kv_cache \
-      -X \
-      --xnnpack-extended-ops \
       --max_context_length 1024 \
       --max_seq_length 1024 \
+      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
+      --verbose \
       --dtype fp32 \
-      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
+      ${BACKEND_ARGS}
     ;;
 
   *)
@@ -104,6 +116,10 @@ if [[ $MODEL_SIZE -gt $EXPECTED_MODEL_SIZE_UPPER_BOUND ]]; then
 fi
 
 # Install ET with CMake
+EXECUTORCH_BUILD_KERNELS_TORCHAO="OFF"
+if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
+  EXECUTORCH_BUILD_KERNELS_TORCHAO="ON"
+fi
 if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
   echo "[runner] Building and testing llama_main ..."
   cmake -DPYTHON_EXECUTABLE=python \
@@ -120,6 +136,7 @@ if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
     -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
     -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
     -DEXECUTORCH_BUILD_KERNELS_LLM=ON \
+    -DEXECUTORCH_BUILD_KERNELS_TORCHAO=${EXECUTORCH_BUILD_KERNELS_TORCHAO} \
     -Bcmake-out .
   cmake --build cmake-out -j16 --config Release --target install
```
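
As a usage note, the updated script can be exercised locally along these lines (model names and flags are taken from the script's own usage text; the torchao path additionally needs the torchao build shown in the workflow change below):

```bash
# torchao path: export with torchao lowbit kernels, then build ET and
# sanity-check the resulting .pte with llama_main
bash .ci/scripts/test_torchao_huggingface_checkpoints.sh qwen3_4b \
    --test_with_runner --use_torchao_kernels

# default path: unchanged XNNPACK lowering (-X --xnnpack-extended-ops)
bash .ci/scripts/test_torchao_huggingface_checkpoints.sh qwen3_4b
```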

.github/workflows/trunk.yml

Lines changed: 17 additions & 5 deletions
```diff
@@ -594,15 +594,22 @@ jobs:
     strategy:
       matrix:
         model: [qwen3_4b, phi_4_mini]
+        runner: [linux.2xlarge]
+        docker-image: [executorch-ubuntu-22.04-clang12]
+        backend: [xnnpack]
         include:
           - model: qwen3_4b
-            test_with_runner: true
+            runner: linux.arm64.2xlarge
+            docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+            backend: torchao
           - model: phi_4_mini
-            test_with_runner: false
+            runner: linux.arm64.2xlarge
+            docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+            backend: torchao
       fail-fast: false
     with:
-      runner: linux.2xlarge
-      docker-image: ci-image:executorch-ubuntu-22.04-clang12
+      runner: ${{ matrix.runner }}
+      docker-image: ci-image:${{ matrix.docker-image }}
       submodules: 'recursive'
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       timeout: 900
@@ -612,9 +619,14 @@ jobs:
       conda activate "${CONDA_ENV}"
 
       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake
+
+      if [[ "${{ matrix.backend }}" == "torchao" ]]; then
+        BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install third-party/ao
+      fi
+
       pip install -U "huggingface_hub[cli]"
 
-      bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }}
+      bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.model != 'phi_4_mini' && '--test_with_runner' || '' }} ${{ matrix.backend == 'torchao' && '--use_torchao_kernels' || '' }}
 
 test-multimodal-macos:
   if: ${{ !github.event.pull_request.head.repo.fork }}
```
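
To reproduce the CI torchao build outside the workflow, the same install should work on an aarch64 machine; this is a sketch with the environment flags copied verbatim from the step above:

```bash
# Build torchao's experimental aarch64 CPU kernels from the ExecuTorch
# submodule, with KleidiAI and NEON dot-product support, using OpenMP.
BUILD_TORCHAO_EXPERIMENTAL=1 \
TORCHAO_BUILD_CPU_AARCH64=1 \
TORCHAO_BUILD_KLEIDIAI=1 \
TORCHAO_ENABLE_ARM_NEON_DOT=1 \
TORCHAO_PARALLEL_BACKEND=OPENMP \
pip install third-party/ao
```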

examples/models/llama/export_llama_lib.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -417,6 +417,21 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
     )
+    parser.add_argument(
+        "--use-torchao-kernels",
+        action="store_true",
+        help="Delegate tied-embedding and quantized linear ops to torchao kernels",
+    )
+    parser.add_argument(
+        "--use-torchao-kernels-tied-embedding",
+        action="store_true",
+        help="Delegate tied-embedding ops to torchao kernels",
+    )
+    parser.add_argument(
+        "--use-torchao-kernels-linear",
+        action="store_true",
+        help="Delegate linear ops to torchao kernels",
+    )
     parser.add_argument("-V", "--vulkan", action="store_true")
     parser.add_argument("--vulkan-force-fp16", action="store_true")
     parser.add_argument("--mps", action="store_true")
@@ -741,6 +756,8 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
+            use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
+            use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
         )
     )
 
@@ -1303,6 +1320,8 @@ def _get_source_transforms(  # noqa
     preq_group_size: Optional[int] = None,
     preq_embedding_quantize: Optional[str] = None,
     local_global_attention: Optional[List[int]] = None,
+    use_torchao_kernels_linear: bool = False,
+    use_torchao_kernels_tied_embedding: bool = False,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
@@ -1475,6 +1494,17 @@ def _get_source_transforms(  # noqa
             )
         )
 
+    if any([use_torchao_kernels_linear, use_torchao_kernels_tied_embedding]):
+        from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
+
+        transforms.append(
+            partial(
+                _convert_model_for_aarch64,
+                convert_linear=use_torchao_kernels_linear,
+                convert_tied_embedding=use_torchao_kernels_tied_embedding,
+            )
+        )
+
     return transforms
```
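
The two granular flags map directly onto `convert_linear` and `convert_tied_embedding` in the `_convert_model_for_aarch64` transform above, so the conversions can be applied independently; for example (entry point assumed as in the sketch under the commit message, `...` standing in for the remaining export args):

```bash
# Convert only linear ops to torchao kernels.
python -m examples.models.llama.export_llama ... --use-torchao-kernels-linear

# Convert only the tied embedding, alongside -X for XNNPACK linears.
python -m examples.models.llama.export_llama ... -X --use-torchao-kernels-tied-embedding
```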

extension/llm/export/config/llm_config.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -452,6 +452,16 @@ class MPSConfig:
     enabled: bool = False
 
 
+@dataclass
+class TorchAOKernelsConfig:
+    """
+    Configures the torchao-kernels backend.
+    """
+
+    use_torchao_kernels_linear: bool = False
+    use_torchao_kernels_tied_embedding: bool = False
+
+
 @dataclass
 class BackendConfig:
     """
@@ -464,6 +474,7 @@ class BackendConfig:
     vulkan: VulkanConfig = field(default_factory=VulkanConfig)
     qnn: QNNConfig = field(default_factory=QNNConfig)
     mps: MPSConfig = field(default_factory=MPSConfig)
+    torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)
 
 
 ################################################################################
@@ -632,6 +643,28 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         if hasattr(args, "mps"):
             llm_config.backend.mps.enabled = args.mps
 
+        # TorchAoKernels
+        if any(
+            hasattr(args, a)
+            for a in [
+                "use_torchao_kernels",
+                "use_torchao_kernels_linear",
+                "use_torchao_kernels_tied_embedding",
+            ]
+        ):
+            if hasattr(args, "use_torchao_kernels") and args.use_torchao_kernels:
+                # Enable all conversions if torchao_kernels is specified
+                llm_config.backend.torchao.use_torchao_kernels_linear = True
+                llm_config.backend.torchao.use_torchao_kernels_tied_embedding = True
+            else:
+                # Otherwise, only enable the conversions that are specified
+                llm_config.backend.torchao.use_torchao_kernels_linear = getattr(
+                    args, "use_torchao_kernels_linear", False
+                )
+                llm_config.backend.torchao.use_torchao_kernels_tied_embedding = getattr(
+                    args, "use_torchao_kernels_tied_embedding", False
+                )
+
         # DebugConfig
         if hasattr(args, "profile_memory"):
             llm_config.debug.profile_memory = args.profile_memory
```
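
Per `from_args` above, the umbrella flag simply enables both granular conversions, so these two invocations should yield the same `TorchAOKernelsConfig` (entry point assumed; `...` stands in for the remaining args):

```bash
python -m examples.models.llama.export_llama ... --use-torchao-kernels

python -m examples.models.llama.export_llama ... \
    --use-torchao-kernels-linear --use-torchao-kernels-tied-embedding
```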

third-party/ao

Submodule ao updated 146 files
