Skip to content

Commit 468f95d

Browse files
Arm backend: support list-of-tensors args in FuseConstantArgsPass (pytorch#13037)
FuseConstantArgsPass is changed so that it no longer unpacks lists or tuples in node.args and node.kwargs. Those sequences are now preserved when the operator is invoked. That means ops like aten::cat, which expect their first argument to be a `List[Tensor]`, still get exactly that. Previously, the pass would flatten the list into separate tensor arguments, leading to this runtime error: [WARNING 2025-07-24 11:37:13,749 fuse_constant_ops_pass.py:133] Failed to fuse constant op aten_cat_default due to exception: aten::cat() Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor'. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 00117cf commit 468f95d

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77

88
import torch._export.utils
9+
import torch.fx
910
from executorch.backends.arm._passes.arm_pass_utils import (
1011
get_constant_placeholder_kind,
1112
get_first_fake_tensor,
@@ -50,22 +51,26 @@ def _fuse_nodes(self, node) -> bool:
5051
the operations already carried out on the data.
5152
"""
5253

53-
# Extract tensors and args from the node
54-
data_list = [
55-
get_param_tensor(self.exported_program, input_node)
56-
for input_node in node.all_input_nodes
57-
]
58-
59-
args = node.args[len(node.all_input_nodes) :]
60-
kwargs = node.kwargs
61-
62-
if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
63-
for i in range(len(node.all_input_nodes)):
64-
q_params = node.meta["input_qparams"][i]
65-
data_list[i] = q_params.dequantize_value(data_list[i])
66-
67-
# Run the op on the extracted tensor
68-
data = node.target(*data_list, *args, **kwargs)
54+
input_nodes = list(node.all_input_nodes)
55+
qparams = node.meta.get("input_qparams", None)
56+
57+
def resolve_arg(arg):
58+
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
59+
idx = input_nodes.index(arg)
60+
t = get_param_tensor(self.exported_program, arg)
61+
if qparams:
62+
t = qparams[idx].dequantize_value(t)
63+
return t
64+
if isinstance(arg, tuple):
65+
return tuple(resolve_arg(x) for x in arg)
66+
if isinstance(arg, list):
67+
return [resolve_arg(x) for x in arg]
68+
return arg
69+
70+
new_args = tuple(resolve_arg(a) for a in node.args)
71+
new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}
72+
73+
data = node.target(*new_args, **new_kwargs)
6974

7075
# Only fuse if the tensor does not get bigger.
7176
if data.numel() > get_first_fake_tensor(node).numel():

backends/arm/test/passes/test_fuse_constant_ops_pass.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1616

1717
input_t = Tuple[torch.Tensor] # Input x
18+
input_t2 = Tuple[torch.Tensor, torch.Tensor]
1819

1920

2021
class FuseParameter(torch.nn.Module):
@@ -86,12 +87,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8687
return operator.add(sliced, x)
8788

8889

90+
class CatConst(torch.nn.Module):
91+
ops_before_pass = {
92+
"executorch_exir_dialects_edge__ops_aten_cat_default": 1,
93+
}
94+
ops_after_pass = {
95+
"executorch_exir_dialects_edge__ops_aten_cat_default": 1,
96+
}
97+
ops_not_after_pass = []
98+
99+
def __init__(self):
100+
super().__init__()
101+
102+
def forward(self, a, b):
103+
return torch.cat((a, b), dim=0)
104+
105+
89106
modules = {
90107
"fuse_parameter": FuseParameter(),
91108
"fuse_buffer": FuseBuffer(),
92109
"fuse_const_tensor": FuseLiftedTensor(),
93110
}
94111

112+
cat_module = {
113+
"fuse_cat": CatConst(),
114+
}
115+
95116

96117
@common.parametrize("module", modules)
97118
def test_fuse_const_ops_tosa_MI(module: torch.nn.Module):
@@ -118,3 +139,16 @@ def test_fuse_const_ops_tosa_BI(module: torch.nn.Module):
118139
passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass],
119140
)
120141
pipeline.run()
142+
143+
144+
@common.parametrize("module", cat_module)
145+
def test_fuse_const_ops_tosa_BI_cat(module: torch.nn.Module):
146+
pipeline = PassPipeline[input_t2](
147+
module,
148+
(torch.rand(3), torch.rand(2)),
149+
quantize=True,
150+
ops_before_pass=module.ops_before_pass,
151+
ops_after_pass=module.ops_after_pass,
152+
passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass],
153+
)
154+
pipeline.run()

0 commit comments

Comments
 (0)