Commit 40c77e4

Update on "[ET-VK] Allow aten.cat.default to handle any number of input tensors"

## Context

Previously, I updated the implementation of `aten.cat.default` in D76305343 (#11508) because the original implementation had a bug. The new implementation only supported up to 3 input tensors, but several models require up to 6. This diff updates the `concat` op so that an arbitrary number of input tensors may be accepted.

## Changes

* Update the implementation of the concat shader so that it can be called repeatedly, allowing support for any number of input tensors.

Differential Revision: [D79893084](https://our.internmc.facebook.com/intern/diff/D79893084/)

[ghstack-poisoned]
2 parents 7cb0276 + 346cc21 commit 40c77e4
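As an illustration of the idea described in the commit message, here is a minimal Python sketch of how a fixed-arity concat kernel can cover an arbitrary number of inputs by being dispatched repeatedly at a running offset. The names `MAX_INPUTS_PER_DISPATCH`, `dispatch_concat`, and `concat_any` are hypothetical stand-ins; the actual change lives in the Vulkan concat shader and its C++ dispatch logic, not in Python.

```python
import torch

# Assumed per-dispatch input limit; the previous implementation supported
# at most 3 inputs per shader call. Name and value are illustrative only.
MAX_INPUTS_PER_DISPATCH = 3


def dispatch_concat(chunk, out, dim, offset):
    # Stand-in for a single shader invocation: write each tensor in `chunk`
    # into `out`, starting at `offset` along `dim`.
    for t in chunk:
        length = t.shape[dim]
        out.narrow(dim, offset, length).copy_(t)
        offset += length


def concat_any(inputs, dim=0):
    # Allocate the full output once, then fill it by calling the
    # limited-arity dispatch repeatedly on successive chunks of inputs.
    out_shape = list(inputs[0].shape)
    out_shape[dim] = sum(t.shape[dim] for t in inputs)
    out = inputs[0].new_empty(out_shape)
    offset = 0
    for start in range(0, len(inputs), MAX_INPUTS_PER_DISPATCH):
        chunk = inputs[start:start + MAX_INPUTS_PER_DISPATCH]
        dispatch_concat(chunk, out, dim, offset)
        offset += sum(t.shape[dim] for t in chunk)
    return out


# Example: six inputs handled with a limit of three per dispatch.
xs = [torch.randn(2, i + 1) for i in range(6)]
assert torch.equal(concat_any(xs, dim=1), torch.cat(xs, dim=1))
```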

16 files changed: +366 −109 lines

.github/workflows/trunk.yml

Lines changed: 14 additions & 6 deletions
@@ -60,7 +60,7 @@ jobs:
   uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
   strategy:
     matrix:
-      model: [add]
+      model: [add, softmax, mv2]
     fail-fast: false
   with:
     runner: linux.2xlarge
@@ -72,6 +72,16 @@ jobs:
       MODEL_NAME=${{ matrix.model }}
       CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
       conda activate "${CONDA_ENV}"
+      if [[ ${{ matrix.model}} == "add" ]]; then
+        SIM_LIMIT_SEC=60
+      elif [[ ${{ matrix.model}} == "softmax" ]]; then
+        SIM_LIMIT_SEC=60
+      elif [[ ${{ matrix.model}} == "mv2" ]]; then
+        SIM_LIMIT_SEC=5000
+      else
+        echo "Failed unsupported model selection ${{ matrix.model }}"
+        exit 1
+      fi

      source .ci/scripts/utils.sh
      source .ci/scripts/zephyr-utils.sh
@@ -118,24 +128,22 @@ jobs:
        -C mps3_board.uart0.out_file='sim.out' \
        -C cpu0.CFGITCMSZ=15 \
        -C cpu0.CFGDTCMSZ=15 \
-       --simlimit 120
+       --simlimit ${SIM_LIMIT_SEC}

      # Disable exit on error
      set +e
      # Report failure if any of the ouptut verification checks fail
-     # store 0 if found (failure), 1 if not (success)
      grep -qF "ERROR" sim.out
-     exit_status=$?
+     exit_status=$? #store 0 if found (failure), 1 if not (success)
      if [[ "$exit_status" -eq "0" ]]; then
        cat sim.out
        set -e
        exit 1
      fi

      # Report fail if simulation does not complete successfully
-     # store 0 if found (success), 1 if not (failure)
      grep -qF "SUCCESS: Program complete, exiting." sim.out
-     exit_status=$?
+     exit_status=$? #store 0 if found (success), 1 if not (failure)
      if [[ "$exit_status" -eq "1" ]]; then
        cat sim.out
        set -e

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 8 additions & 0 deletions
@@ -167,6 +167,10 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
        et_prog = delegated_program.to_executorch()
        self._compare_outputs(et_prog, model, example_inputs)

+    @unittest.skipIf(
+        not hasattr(torch.version, "git_version"),
+        "Enable in fbcode once D79658061 lands",
+    )
    def test_dequantize_codebook_linear(self):
        model, example_inputs = self._get_test_model()
        quantize_(
@@ -194,6 +198,10 @@ def test_dequantize_codebook_linear(self):
        et_prog = delegated_program.to_executorch()
        self._compare_outputs(et_prog, model, example_inputs)

+    @unittest.skipIf(
+        not hasattr(torch.version, "git_version"),
+        "Enable in fbcode once D79658061 lands",
+    )
    def test_dequantize_codebook_embedding(self):
        model, example_inputs = self._get_test_model()
        quantize_(

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
 from .decompose_atanh_pass import DecomposeAtanhPass  # noqa
 from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa
 from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
+from .decompose_cosh_pass import DecomposeCoshPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@
     DecomposeAtanPass,
     DecomposeAvgPool2d,
     DecomposeBatchNormNoStatsPass,
+    DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
@@ -167,6 +168,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeAcoshPass())
         self.add_pass(DecomposeAsinPass())
         self.add_pass(DecomposeAsinhPass())
+        self.add_pass(DecomposeCoshPass())
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
backends/arm/_passes/decompose_cosh_pass.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+# For MI case
+edge_cosh = exir_ops.edge.aten.cosh.default
+
+
+class DecomposeCoshPass(ArmPass):
+    """
+    This pass replaces the cosh operator with a sequence of TOSA-equivalent operations that
+    compute the hyperbolic cosine using the formula:
+
+    cosh(x) = 0.5 * (e^x + e^(⁻x))
+
+    """
+
+    def call_operator(self, op, args, kwargs, meta, updated=False):
+        if op is not edge_cosh:
+            return super().call_operator(op, args, kwargs, meta, updated)
+
+        x = args
+
+        exp_op, mul_op, neg_op, add_op = (
+            exir_ops.edge.aten.exp.default,
+            exir_ops.edge.aten.mul.Scalar,
+            exir_ops.edge.aten.neg.default,
+            exir_ops.edge.aten.add.Tensor,
+        )
+
+        # exp1 = e^x
+        exp1 = super().call_operator(exp_op, x, {}, meta, updated=True)
+
+        # exp2 = e^(⁻x)
+        neg_x = super().call_operator(neg_op, x, {}, meta, updated=True)
+        exp2 = super().call_operator(exp_op, (neg_x,), {}, meta, updated=True)
+
+        # numer = exp1 + exp2
+        numer = super().call_operator(add_op, (exp1, exp2), {}, meta, updated=True)
+
+        # out = 0.5 * numer
+        out = super().call_operator(mul_op, (numer, 0.5), {}, meta, updated=True)
+
+        return out
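For reference, a quick numerical sanity check of the decomposition used by this pass (illustrative only, not part of the diff): 0.5 * (e^x + e^(-x)) should match torch.cosh elementwise.

```python
import torch

# Verify the decomposition cosh(x) = 0.5 * (e^x + e^(-x)) numerically.
x = torch.linspace(-4.0, 4.0, steps=9)
decomposed = 0.5 * (torch.exp(x) + torch.exp(-x))
assert torch.allclose(decomposed, torch.cosh(x), atol=1e-6)
```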

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ class TableOps:
         exir_ops.edge.aten.acosh.default: torch.acosh,
         exir_ops.edge.aten.asin.default: torch.asin,
         exir_ops.edge.aten.asinh.default: torch.asinh,
+        exir_ops.edge.aten.cosh.default: torch.cosh,
     }

     # Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
@@ -257,6 +257,7 @@ def is_node_supported(
             exir_ops.edge.aten.addmm.default,
             exir_ops.edge.aten.masked_fill.Scalar,
             exir_ops.edge.aten.asinh.default,
+            exir_ops.edge.aten.cosh.default,
         ]

         return supported

backends/arm/process_node.py

Lines changed: 2 additions & 6 deletions
@@ -8,11 +8,12 @@
 from typing import Any, cast, Dict

 import numpy as np
+import serializer.tosa_serializer as ts
 import torch
 import torch.fx
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
 from torch._export.utils import (
     get_buffer,
@@ -81,11 +82,6 @@ def process_inputs(
             "Is the original torch function supported?"
         ) from e

-    if isinstance(tosa_spec, Tosa_1_00):
-        import serializer.tosa_serializer as ts
-    else:
-        raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")
-
     input_shape = tosa_arg.shape
     input_dim_order = tosa_arg.dim_order
     tensor = ts.TosaSerializerTensor(

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
@@ -287,6 +287,7 @@ def _match_pattern(
     torch.ops.aten.asin.default,
     torch.ops.aten.atanh.default,
     torch.ops.aten.asinh.default,
+    torch.ops.aten.cosh.default,
 ]

 _one_to_one_shared_input_qspec = [

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import common
+import pytest
+
+import torch
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+from torchvision import models, transforms
+
+ic3 = models.inception_v3(weights=models.Inception_V3_Weights)
+ic3 = ic3.eval()
+
+# Normalization values referenced from here:
+# https://docs.pytorch.org/vision/main/models/generated/torchvision.models.quantization.inception_v3.html
+normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+model_inputs = (normalize(torch.rand(1, 3, 224, 224)),)
+input_t = Tuple[torch.Tensor]
+
+
+@pytest.mark.slow
+def test_ic3_tosa_FP():
+    pipeline = TosaPipelineFP[input_t](
+        ic3,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+def test_ic3_tosa_BI():
+    pipeline = TosaPipelineINT[input_t](
+        ic3,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+        atol=0.5,
+        qtol=1,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@pytest.mark.skip(reason="Takes too long to run on CI")
+@common.XfailIfNoCorstone300
+def test_ic3_u55_BI():
+    pipeline = EthosU55PipelineINT[input_t](
+        ic3,
+        model_inputs,
+        aten_ops=[],
+        exir_ops=[],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+        atol=0.5,
+        qtol=1,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@pytest.mark.skip(reason="Takes too long to run on CI")
+@common.XfailIfNoCorstone320
+def test_ic3_u85_BI():
+    pipeline = EthosU85PipelineINT[input_t](
+        ic3,
+        model_inputs,
+        aten_ops=[],
+        exir_ops=[],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+        atol=0.5,
+        qtol=1,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@pytest.mark.skip(reason="Takes too long to run on CI")
+@common.SkipIfNoModelConverter
+def test_ic3_vgf_FP():
+    pipeline = VgfPipeline[input_t](
+        ic3,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        tosa_version="TOSA-1.0+FP",
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@pytest.mark.skip(reason="Takes too long to run on CI")
+@common.SkipIfNoModelConverter
+def test_ic3_vgf_INT():
+    pipeline = VgfPipeline[input_t](
+        ic3,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        tosa_version="TOSA-1.0+INT",
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()

0 commit comments
