Skip to content

Commit dcff989

Browse files
committed
Add torchao conversion
1 parent b3f3111 commit dcff989

File tree

5 files changed

+59
-12
lines changed

5 files changed

+59
-12
lines changed

.ci/scripts/test_torchao_huggingface_checkpoints.sh

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ set -euxo pipefail
55
# Args / flags
66
# -------------------------
77
TEST_WITH_RUNNER=0
8+
USE_TORCHAO_KERNELS=0
89
MODEL_NAME=""
910

1011
# Parse args
@@ -22,10 +23,14 @@ while [[ $# -gt 0 ]]; do
2223
--test_with_runner)
2324
TEST_WITH_RUNNER=1
2425
;;
26+
--use_torchao_kernels)
27+
USE_TORCHAO_KERNELS=1
28+
;;
2529
-h|--help)
26-
echo "Usage: $0 <model_name> [--test_with_runner]"
30+
echo "Usage: $0 <model_name> [--test_with_runner] [--use_torchao_kernels]"
2731
echo " model_name: qwen3_4b | phi_4_mini"
2832
echo " --test_with_runner: build ET + run llama_main to sanity-check the export"
33+
echo " --use_torchao_kernels: use torchao kernels for linear and tied embedding"
2934
exit 0
3035
;;
3136
*)
@@ -42,6 +47,13 @@ fi
4247

4348
MODEL_OUT=model.pte
4449

50+
51+
# Default to XNNPACK
52+
BACKEND_ARGS="-X --xnnpack-extended-ops"
53+
if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
54+
BACKEND_ARGS="--torchao-kernels"
55+
fi
56+
4557
case "$MODEL_NAME" in
4658
qwen3_4b)
4759
echo "Running Qwen3-4B export..."
@@ -58,12 +70,12 @@ case "$MODEL_NAME" in
5870
--output_name $MODEL_OUT \
5971
-kv \
6072
--use_sdpa_with_kv_cache \
61-
-X \
62-
--xnnpack-extended-ops \
6373
--max_context_length 1024 \
6474
--max_seq_length 1024 \
75+
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
76+
--verbose \
6577
--dtype fp32 \
66-
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
78+
${BACKEND_ARGS}
6779
;;
6880

6981
phi_4_mini)
@@ -81,12 +93,12 @@ case "$MODEL_NAME" in
8193
--output_name $MODEL_OUT \
8294
-kv \
8395
--use_sdpa_with_kv_cache \
84-
-X \
85-
--xnnpack-extended-ops \
8696
--max_context_length 1024 \
8797
--max_seq_length 1024 \
98+
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
99+
--verbose \
88100
--dtype fp32 \
89-
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
101+
${BACKEND_ARGS}
90102
;;
91103

92104
*)
@@ -120,6 +132,7 @@ if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
120132
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
121133
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
122134
-DEXECUTORCH_BUILD_KERNELS_LLM=ON \
135+
-DEXECUTORCH_BUILD_KERNELS_TORCHAO=ON \
123136
-Bcmake-out .
124137
cmake --build cmake-out -j16 --config Release --target install
125138

.github/workflows/trunk.yml

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,15 +594,24 @@ jobs:
594594
strategy:
595595
matrix:
596596
model: [qwen3_4b, phi_4_mini]
597+
runner: [linux.2xlarge, linux.arm64.2xlarge]
598+
docker-image: [executorch-ubuntu-22.04-clang12, executorch-ubuntu-22.04-gcc11-aarch64]
597599
include:
598600
- model: qwen3_4b
599601
test_with_runner: true
600602
- model: phi_4_mini
601603
test_with_runner: false
604+
- runner: linux.2xlarge
605+
use_torchao_kernels: false
606+
- runner: linux.arm64.2xlarge
607+
use_torchao_kernels: true
608+
exclude:
609+
- runner: linux.2xlarge
610+
docker-image: executorch-ubuntu-22.04-gcc11-aarch64
611+
- runner: linux.arm64.2xlarge
612+
docker-image: executorch-ubuntu-22.04-clang12
602613
fail-fast: false
603614
with:
604-
runner: linux.2xlarge
605-
docker-image: ci-image:executorch-ubuntu-22.04-clang12
606615
submodules: 'recursive'
607616
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
608617
timeout: 900
@@ -611,10 +620,10 @@ jobs:
611620
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
612621
conda activate "${CONDA_ENV}"
613622
614-
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake
623+
PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_KLEIDIAI=1 bash .ci/scripts/setup-linux.sh --build-tool cmake
615624
pip install -U "huggingface_hub[cli]"
616625
617-
bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }}
626+
bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }} ${{ matrix.use_torchao_kernels && '--use_torchao_kernels' || '' }}
618627
619628
test-multimodal-macos:
620629
if: ${{ !github.event.pull_request.head.repo.fork }}

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,11 @@ def build_args_parser() -> argparse.ArgumentParser:
417417
action="store_true",
418418
help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
419419
)
420+
parser.add_argument(
421+
"--torchao-kernels",
422+
action="store_true",
423+
help="Delegate tied-embedding and quantized linear ops to torchao kernels",
424+
)
420425
parser.add_argument("-V", "--vulkan", action="store_true")
421426
parser.add_argument("--vulkan-force-fp16", action="store_true")
422427
parser.add_argument("--mps", action="store_true")
@@ -741,6 +746,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
741746
preq_group_size=llm_config.base.preq_group_size,
742747
preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
743748
local_global_attention=llm_config.model.local_global_attention,
749+
use_torchao_kernels=llm_config.backend.torchao.enabled,
744750
)
745751
)
746752

@@ -1303,6 +1309,7 @@ def _get_source_transforms( # noqa
13031309
preq_group_size: Optional[int] = None,
13041310
preq_embedding_quantize: Optional[str] = None,
13051311
local_global_attention: Optional[List[int]] = None,
1312+
use_torchao_kernels: bool = False,
13061313
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
13071314
"""
13081315
Return a list of functions that transform a graph.
@@ -1475,6 +1482,11 @@ def _get_source_transforms( # noqa
14751482
)
14761483
)
14771484

1485+
if use_torchao_kernels:
1486+
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
1487+
1488+
transforms.append(_convert_model_for_aarch64)
1489+
14781490
return transforms
14791491

14801492

extension/llm/export/config/llm_config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,15 @@ class MPSConfig:
452452
enabled: bool = False
453453

454454

455+
@dataclass
456+
class TorchAOKernelsConfig:
457+
"""
458+
Configures the torchao-kernels backend.
459+
"""
460+
461+
enabled: bool = False
462+
463+
455464
@dataclass
456465
class BackendConfig:
457466
"""
@@ -464,6 +473,7 @@ class BackendConfig:
464473
vulkan: VulkanConfig = field(default_factory=VulkanConfig)
465474
qnn: QNNConfig = field(default_factory=QNNConfig)
466475
mps: MPSConfig = field(default_factory=MPSConfig)
476+
torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)
467477

468478

469479
################################################################################
@@ -632,6 +642,9 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
632642
if hasattr(args, "mps"):
633643
llm_config.backend.mps.enabled = args.mps
634644

645+
if hasattr(args, "torchao_kernels"):
646+
llm_config.backend.torchao.enabled = args.torchao_kernels
647+
635648
# DebugConfig
636649
if hasattr(args, "profile_memory"):
637650
llm_config.debug.profile_memory = args.profile_memory

third-party/ao

Submodule ao updated 145 files

0 commit comments

Comments
 (0)