Commit d88d480

Merge branch 'main' into gh/larryliu0820/77/base
2 parents 56b6ac6 + 52330d5 commit d88d480

14 files changed: +524 -40 lines changed

.ci/scripts/test_torchao_huggingface_checkpoints.sh

Lines changed: 24 additions & 7 deletions

@@ -5,6 +5,7 @@ set -euxo pipefail
 # Args / flags
 # -------------------------
 TEST_WITH_RUNNER=0
+USE_TORCHAO_KERNELS=0
 MODEL_NAME=""

 # Parse args
@@ -22,10 +23,14 @@ while [[ $# -gt 0 ]]; do
     --test_with_runner)
       TEST_WITH_RUNNER=1
       ;;
+    --use_torchao_kernels)
+      USE_TORCHAO_KERNELS=1
+      ;;
     -h|--help)
-      echo "Usage: $0 <model_name> [--test_with_runner]"
+      echo "Usage: $0 <model_name> [--test_with_runner] [--use_torchao_kernels]"
       echo "  model_name: qwen3_4b | phi_4_mini"
       echo "  --test_with_runner: build ET + run llama_main to sanity-check the export"
+      echo "  --use_torchao_kernels: use torchao kernels for linear and tied embedding"
       exit 0
       ;;
     *)
@@ -42,6 +47,13 @@ fi

 MODEL_OUT=model.pte

+
+# Default to XNNPACK
+BACKEND_ARGS="-X --xnnpack-extended-ops"
+if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
+  BACKEND_ARGS="--use-torchao-kernels"
+fi
+
 case "$MODEL_NAME" in
   qwen3_4b)
     echo "Running Qwen3-4B export..."
@@ -58,12 +70,12 @@ case "$MODEL_NAME" in
       --output_name $MODEL_OUT \
      -kv \
      --use_sdpa_with_kv_cache \
-      -X \
-      --xnnpack-extended-ops \
      --max_context_length 1024 \
      --max_seq_length 1024 \
+      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
+      --verbose \
      --dtype fp32 \
-      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
+      ${BACKEND_ARGS}
     ;;

   phi_4_mini)
@@ -81,12 +93,12 @@ case "$MODEL_NAME" in
       --output_name $MODEL_OUT \
      -kv \
      --use_sdpa_with_kv_cache \
-      -X \
-      --xnnpack-extended-ops \
      --max_context_length 1024 \
      --max_seq_length 1024 \
+      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
+      --verbose \
      --dtype fp32 \
-      --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
+      ${BACKEND_ARGS}
     ;;

   *)
@@ -104,6 +116,10 @@ if [[ $MODEL_SIZE -gt $EXPECTED_MODEL_SIZE_UPPER_BOUND ]]; then
 fi

 # Install ET with CMake
+EXECUTORCH_BUILD_KERNELS_TORCHAO="OFF"
+if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
+  EXECUTORCH_BUILD_KERNELS_TORCHAO="ON"
+fi
 if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
   echo "[runner] Building and testing llama_main ..."
   cmake -DPYTHON_EXECUTABLE=python \
@@ -120,6 +136,7 @@ if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
     -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
    -DEXECUTORCH_BUILD_KERNELS_LLM=ON \
+    -DEXECUTORCH_BUILD_KERNELS_TORCHAO=${EXECUTORCH_BUILD_KERNELS_TORCHAO} \
    -Bcmake-out .
   cmake --build cmake-out -j16 --config Release --target install

.github/workflows/add-unanswered-to-project.yml

Lines changed: 20 additions & 10 deletions

@@ -24,16 +24,27 @@ jobs:
         "manuelcandales", "metascroy", "cccclai", "rohansjoshi", "kirklandsign", "abhinaykukkadapu", "JacobSzwejbka",
         "Conarnar", "lucylq", "larryliu0820", "BujSet", "Gasoonjia", "Juntian777", "guangy10", "jackzhxng",
         "GregoryComer", "leafs1", "swolchok", "mergennachin", "tarun292", "byjlw", "jathu", "Jack-Khuu", "georgehong",
-        "zhenyan-zhang-meta", "silverguo", "dbort", "jorgep31415", "huydhn", "mcremon-meta", "trivedivivek", "angelayi",
-        "helunwencser", "hsharma35", "zhxchen17", "iseeyuan", "svekars", "nathanaelsee", "dulinriley", "jerryzh168",
+        "zhenyan-zhang-meta", "silverguo", "harishs88ss", "AlannaBurke", "dbort", "huydhn", "mcremon-meta", "trivedivivek",
+        "angelayi", "helunwencser", "hsharma35", "zhxchen17", "iseeyuan", "svekars", "nathanaelsee", "dulinriley", "jerryzh168",
         "cmodi-meta", "bigfootjon", "sxu", "ydwu4", "Riandy", "tugsbayasgalan", "bsoyluoglu", "yangw-dev", "YIWENX14",
         "namanahuja", "yushangdi", "limintang", "pianpwk", "viveknayakatmeta", "andreanicastro", "JakeStevens",
-        "gmagogsfm", "zonglinpeng", "eigen-k", "derekxu", "salilsdesai", "skrtskrtfb", "pssrawat", "r-barnes", "pytorchbot",
-        "pytorchmergebot", "pytorchupdatebot", "facebook-github-bot", "Erik-Lundell", "zingo", "AdrianLundell",
-        "oscarandersson8218", "per", "Sebastian-Larsson", "SaoirseARM", "robell", "mansnils", "martinlsm", "freddan80",
-        "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind", "haowhsu-quic", "shewu-quic",
-        "winskuo-quic", "chunit-quic", "DannyYuyang-quic", "chuntl", "cymbalrush", "DenisVieriu97", "billmguo",
-        "StrycekSimon", "jirioc", "robert-kalmar", "skywall", "neuropilot-captain"
+        "gmagogsfm", "zonglinpeng", "eigen-k", "derekxu", "salilsdesai", "skrtskrtfb", "pssrawat", "r-barnes",
+        "kalpit-meta-1", "Will-MingLun-Li", "KapJI", "piyengar", "j-bahr", "BoyuanFeng", "fgasperij", "DariusHolmgren",
+        "sammarden-meta", "kushrast", "meta-emilian", "Rittzz", "jeanschmidt", "copyrightly", "mikekgfb", "vmpuri",
+        "zonglinpengmeta", "maggiemoss", "aorenste", "hoangminhle98", "Solumin", "meyering", "rchen152",
+        "AishwaryaSivaraman", "migeed-z", "ebgraham", "Esteb37", "nausicaasnow", "Camyll", "ezyang", "huiyujie",
+        "dltn", "cjhopman", "blackm00n", "agunapal", "SamGondelman", "Ninja91", "ivayloen", "DrJessop", "rodrigos01meta",
+        "akrieger", "cmt0", "yiming0416", "ethansfng", "ThomasJannaud", "nirvanagth", "marcinkwiatkowski", "3l1",
+        "omerjerk", "nitish2112", "yipjustin", "ejnguyen", "andrewor14", "phaiting", "mgiordy", "LeeOHzzZ", "adicatana",
+        "Polyomino", "ezrilow", "navsud", "YifanShenSZ", "RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat", "azad-meta",
+        "pytorchbot", "pytorchmergebot", "pytorchupdatebot", "facebook-github-bot", "app/dependabot", "Erik-Lundell",
+        "zingo", "AdrianLundell", "oscarandersson8218", "per", "Sebastian-Larsson", "SaoirseARM", "robell", "mansnils",
+        "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind",
+        "benkli01", "Tessil", "maddun01", "Michiel-Olieslagers", "armwaheed", "agrima1304", "emmakujala", "annietllnd",
+        "haowhsu-quic", "shewu-quic", "winskuo-quic", "chunit-quic", "DannyYuyang-quic", "chuntl", "thchenqti",
+        "jethroqti", "cymbalrush", "DenisVieriu97", "billmguo", "StrycekSimon", "jirioc", "robert-kalmar", "skywall",
+        "MartinPavella", "roman-janik-nxp", "novak-vaclav ", "neuropilot-captain", "dijopaul", "cad-rlc", "cad-audio",
+        "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", "cavusmustafa", "Jiseong-oh", "alexdean08"
       ]);

       async function addItem(contentId, type, number) {
@@ -80,11 +91,10 @@ jobs:
           owner,
          repo,
          state: 'open',
-          draft: false,
         }
       );
       for (const pr of prs) {
-        if (!excludedAuthors.has(pr.user.login)) {
+        if (!pr.draft && !excludedAuthors.has(pr.user.login)) {
           await addItem(pr.node_id, 'pr', pr.number);
         }
       }

.github/workflows/trunk.yml

Lines changed: 17 additions & 5 deletions

@@ -594,15 +594,22 @@ jobs:
     strategy:
       matrix:
         model: [qwen3_4b, phi_4_mini]
+        runner: [linux.2xlarge]
+        docker-image: [executorch-ubuntu-22.04-clang12]
+        backend: [xnnpack]
         include:
           - model: qwen3_4b
-            test_with_runner: true
+            runner: linux.arm64.2xlarge
+            docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+            backend: torchao
           - model: phi_4_mini
-            test_with_runner: false
+            runner: linux.arm64.2xlarge
+            docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+            backend: torchao
       fail-fast: false
     with:
-      runner: linux.2xlarge
-      docker-image: ci-image:executorch-ubuntu-22.04-clang12
+      runner: ${{ matrix.runner }}
+      docker-image: ci-image:${{ matrix.docker-image }}
       submodules: 'recursive'
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      timeout: 900
@@ -612,9 +619,14 @@ jobs:
       conda activate "${CONDA_ENV}"

       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake
+
+      if [[ "${{ matrix.backend }}" == "torchao" ]]; then
+        BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install third-party/ao
+      fi
+
       pip install -U "huggingface_hub[cli]"

-      bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }}
+      bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.model != 'phi_4_mini' && '--test_with_runner' || '' }} ${{ matrix.backend == 'torchao' && '--use_torchao_kernels' || '' }}

   test-multimodal-macos:
     if: ${{ !github.event.pull_request.head.repo.fork }}

backends/arm/operators/op_add.py

Lines changed: 16 additions & 1 deletion

@@ -64,12 +64,18 @@ def define_node(
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
+        elif inputs[0].dtype == ts.DType.INT16:
+            rescaled_inputs, scale_back = (
+                tqutils.insert_rescale_ops_int16_to_int32_maxscale(
+                    tosa_graph, inputs, node, self.tosa_spec
+                )
+            )
         else:
             # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
            # Non quantized input, natively support by TOSA.ADD
            rescaled_inputs = inputs

-        if output.dtype == ts.DType.INT8:
+        if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
            add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
@@ -99,6 +105,15 @@ def define_node(
                 compute_rescale=False,
                tosa_spec=self.tosa_spec,
             )  # type: ignore[possibly-undefined]
+        elif output.dtype == ts.DType.INT16:
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph,
+                add_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]


 @register_node_visitor
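
The new INT16 input path reuses the INT32 intermediate that the INT8 path already allocates: the sum of two zero-point-corrected int16 operands needs up to 18 bits, so it cannot be produced directly in an int16 tensor. A minimal sketch of that width argument, using plain torch tensors rather than the actual TOSA lowering:

# int16 addition wraps on overflow, while the same sums are exact in int32.
import torch

a = torch.tensor([32767, -32768], dtype=torch.int16)
b = torch.tensor([32767, -32768], dtype=torch.int16)

print((a + b).tolist())                                   # [-2, 0]         (wrapped)
print((a.to(torch.int32) + b.to(torch.int32)).tolist())   # [65534, -65536] (exact)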

backends/arm/test/ops/test_add.py

Lines changed: 0 additions & 6 deletions

@@ -276,9 +276,6 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):

 @common.parametrize("test_data", Add.test_data)
 @common.XfailIfNoCorstone300
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
-)
 def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
     """Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
    per_channel_quantization = False
@@ -304,9 +301,6 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):

 @common.parametrize("test_data", Add.test_data)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
-)
 def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
     """Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
    per_channel_quantization = False

backends/arm/test/ops/test_to_copy.py

Lines changed: 3 additions & 8 deletions

@@ -192,20 +192,15 @@ def test_to_vgf_INT(test_data: Tuple):
     ),
 }

-redundant_xfails_FP = {
+redundant_xfails = {
     "rand_fp16_fp16": "FP16 is not supported",
    "rand_int8_int8": "Tracing graph with quantized input is not supported.",
    "rand_int16_int16": "Tracing graph with quantized input is not supported.",
 }

-redundant_xfails_INT = {
-    "rand_fp16_fp16": "FP16 is not supported",
-    "rand_int8_int8": "Tracing graph with quantized input is not supported.",
-}
-

 @common.parametrize(
-    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_FP
+    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
 )
 def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
     test_tensor, new_dtype = test_data()
@@ -220,7 +215,7 @@ def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):


 @common.parametrize(
-    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_INT
+    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
 )
 def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
     test_tensor, new_dtype = test_data()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ def define_arm_tests():
     "ops/test_tanh.py",
    "ops/test_view.py",
    "ops/test_cos.py",
+    "ops/test_to_copy.py",
 ]

 # Quantization

backends/arm/tosa/quant_utils.py

Lines changed: 53 additions & 0 deletions

@@ -77,6 +77,59 @@ def insert_rescale_ops_to_int32_maxscale(
     return [rescaled_lhs, rescaled_rhs], back_scale


+def insert_rescale_ops_int16_to_int32_maxscale(
+    tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
+) -> tuple[list[Any], float]:
+    """For ADD and SUB with int16 inputs, we rescale to int32 using a different common scale (2*max(left scale, right scale))
+    compared to all the other cases. We multiply the left and right scales by 1<<12, giving us extra precision
+    for the computation without overflowing.
+
+    Returns a list of the rescaled nodes and the scale factor used,
+    needed by insert_rescale_op_to_int16.
+    """
+
+    if len(inputs) > 2:
+        raise ValueError("More than two inputs not supported")
+
+    tensors = inputs.copy()
+    # Reshape tensor according to TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    input_qparams = get_input_qparams(node)
+    lhs_qparams, rhs_qparams = input_qparams.values()
+    lhs_scale = lhs_qparams.get_scale_per_tensor()
+    rhs_scale = rhs_qparams.get_scale_per_tensor()
+    # Common scale for the two numbers
+    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
+    SHIFT_INT16 = 12
+    # We are adding two int16 numbers. If the zero point is non-null, the result will be in the range [-131070;131070], therefore we need 18 bits for the result.
+    # We have a 32-bit accumulator, so we can shift to the left by 12 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale)
+    # we are shifting to the left by 11.
+    lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x
+    rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x
+    rescaled_lhs = build_rescale_to_int32(
+        tosa_graph,
+        tensors[0],
+        lhs_qparams.get_zp_per_tensor(),
+        lhs_factor,
+        tosa_spec=tosa_spec,
+    )
+    rescaled_rhs = build_rescale_to_int32(
+        tosa_graph,
+        tensors[1],
+        rhs_qparams.get_zp_per_tensor(),
+        rhs_factor,
+        tosa_spec=tosa_spec,
+    )
+    out_qparam = get_output_qparams(node)[0]
+    out_scale = out_qparam.get_scale_per_tensor()
+    back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16))
+
+    return [rescaled_lhs, rescaled_rhs], back_scale
+
+
 def insert_rescale_ops_to_int32(
     tosa_graph: Any,
    inputs: list[TosaArg],
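
As a sanity check of the factor arithmetic above, the computation can be replayed in plain Python. This is a float simulation of the integer RESCALE ops, with zero points taken as 0 and the scales invented for illustration; it is not a substitute for the TOSA graph machinery:

SHIFT_INT16 = 12

def int16_add_via_int32(q_lhs, lhs_scale, q_rhs, rhs_scale, out_scale):
    # Common scale, as in insert_rescale_ops_int16_to_int32_maxscale.
    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
    lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x
    rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x
    # Rescale both operands into the int32 accumulator. The worst case stays
    # below 2**31: |q| < 2**17 after zero-point correction, factor <= 2**11.
    acc = round(q_lhs * lhs_factor) + round(q_rhs * rhs_factor)
    assert abs(acc) < 2**31
    # Rescale the accumulator back to the int16 output scale (back_scale).
    back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16))
    return round(acc * back_scale)

# Quantize 1.5 and -0.25 with different scales, add, and dequantize.
lhs_scale, rhs_scale, out_scale = 1e-4, 3e-4, 4e-4
q_out = int16_add_via_int32(round(1.5 / lhs_scale), lhs_scale,
                            round(-0.25 / rhs_scale), rhs_scale, out_scale)
assert abs(q_out * out_scale - 1.25) <= out_scale  # recovers ~1.25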

backends/cadence/aot/ref_implementations.py

Lines changed: 45 additions & 0 deletions

@@ -933,6 +933,51 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


+@impl(m, "convolution")
+def convolution(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    groups: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    conv_is_1d = len(input_tensor.shape) == 3
+    if channel_last:
+        if conv_is_1d:
+            input_tensor = input_tensor.movedim(-1, 1).contiguous()
+            if len(weight.shape) != 3:
+                raise ValueError("Weight tensor must be 3D if input is 3D")
+            weight = weight.movedim(-1, 1).contiguous()
+        else:
+            input_tensor = input_tensor.movedim(-1, -3)
+            if len(weight.shape) != 4:
+                raise ValueError("Weight tensor must be 4D if input is nd > 3")
+            weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
+
+    _stride: tuple[int, int] | int = stride
+    _padding: tuple[int, int] | int = padding
+    _dilation: tuple[int, int] | int = dilation
+    if conv_is_1d:
+        conv = torch.nn.functional.conv1d
+        _stride = stride[0]
+        _padding = padding[0]
+        _dilation = dilation[0]
+    else:
+        conv = torch.nn.functional.conv2d
+
+    conv_out = conv(input_tensor, weight, bias, _stride, _padding, _dilation, groups)
+    if channel_last:
+        if conv_is_1d:
+            conv_out = conv_out.movedim(1, -1).contiguous()
+        else:
+            conv_out = conv_out.movedim(-3, -1).contiguous()
+
+    return conv_out
+
+
 def quantized_relu_common(
     X: torch.Tensor,
    X_zero_point: torch.Tensor | int,
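
The new convolution reference op wraps torch.nn.functional.conv1d/conv2d; with channel_last set it accepts NLC/NHWC activations (and weights with the channel dimension last) and moves everything through PyTorch's native channel-first layouts around the call. A quick parity sketch; the import path and the direct call of the decorated function are assumptions, adjust to your checkout:

import torch
from executorch.backends.cadence.aot.ref_implementations import convolution

x = torch.randn(1, 8, 8, 3)   # NHWC (channel_last) activation
w = torch.randn(4, 3, 3, 3)   # OHWI weight, channel dim last
b = torch.zeros(4)

# Channel-last path through the reference op.
out = convolution(x, w, b, stride=(1, 1), padding=(0, 0),
                  dilation=(1, 1), groups=1, channel_last=True)

# The same convolution in native NCHW/OIHW layouts.
ref = torch.nn.functional.conv2d(x.movedim(-1, -3), w.permute(0, 3, 1, 2), b)
torch.testing.assert_close(out.movedim(-1, -3), ref)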
