Commit 20a2133

mcr229 authored and facebook-github-bot committed
Support Empty Input Tensors and > 5 Cat Inputs
Summary: PyTorch's cat.default operator can take an arbitrarily large number of inputs, because its input is a Tensor List. XNNPACK, however, supports at most 5 input tensors at a time. Since concatenating more than 5 tensors is common, we should still support such cats. We do so by adding a pass that decomposes the Cat operator: the first 5 inputs are concatenated together, and then we recursively inject additional concatenate nodes, each concatenating the result of the previous operation with the next 4 input tensors.

Another common design pattern is to start with an empty tensor and concatenate tensors into it, which results in some empty tensors as inputs to concatenate. Previously we did not partition inputs with empty tensors. I don't remember what the issue with empty tensors was, but it seems to work now, so we disable that partitioner check for now. Perhaps CI will pick up an error if this is indeed erroneous.

Differential Revision: D68523312
1 parent ef2444f commit 20a2133
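
For intuition, here is a minimal standalone sketch (not the pass itself) of the decomposition described in the summary. decompose_cat_ref is a hypothetical helper that mirrors the pass at the tensor level; it shows that nesting cats of at most 5 inputs reproduces a single large cat:

import torch

def decompose_cat_ref(tensors, dim=0, max_inputs=5):
    # First cat takes up to `max_inputs` tensors; each following cat takes
    # the previous result plus up to (max_inputs - 1) more tensors.
    result = torch.cat(tensors[:max_inputs], dim)
    remaining = list(tensors[max_inputs:])
    while remaining:
        group, remaining = remaining[: max_inputs - 1], remaining[max_inputs - 1 :]
        result = torch.cat([result] + group, dim)
    return result

tensors = [torch.randn(1, 2, 3) for _ in range(9)]
assert torch.equal(decompose_cat_ref(tensors), torch.cat(tensors))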

File tree

7 files changed: +171 -53

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 14 deletions
@@ -4,20 +4,7 @@ oncall("executorch")
 
 python_library(
     name = "xnnpack_passes",
-    srcs = [
-        "__init__.py",
-        "channels_last_tagged_reshape_pass.py",
-        "conv1d_unsqueeze_pass.py",
-        "convert_to_linear.py",
-        "convert_to_sdpa.py",
-        "convert_to_upsample_bilinear2d.py",
-        "fuse_activation_pass.py",
-        "fuse_batch_norm_with_conv.py",
-        "prelu_reshape_pass.py",
-        "remove_getitem_op.py",
-        "tag_implicit_q_dq_pass.py",
-        "xnnpack_pass.py",
-    ],
+    srcs = native.glob(["*.py"]),
     deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:addmm_mm_to_linear",

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
     FuseBatchNormWithConvPass,
 )
+from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
@@ -63,6 +64,7 @@ def __init__(
             ConstPropPass,
             FuseBatchNormWithConvPass,
             FuseActivationPass,
+            DecomposeConcatenate,
             RemoveGetItemPass,
             Conv1dUnsqueezePass,
             PReLUReshapePass,
backends/xnnpack/_passes/decompose_cat.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeConcatenate(ExportPass):
+    """
+    XNNPACK's Concatenate operation only supports concatenation of <= 5 tensors
+    at a time. As a result, to support concatenates with > 5 tensors, we
+    decompose them into sequences of cats, each with <= 5 tensors.
+
+    Example:
+        Before Pass:
+            cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
+
+        After Pass:
+            cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
+            cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        gm = graph_module
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target.__name__ == "aten.cat.default"
+            ):
+                concat_args = node.args
+                nodes_to_concat = node.args[0]
+                if len(nodes_to_concat) <= 5:
+                    continue
+
+                is_quantized = all(
+                    is_dequant(n) for n in nodes_to_concat
+                ) and all(is_quant(n) for n in node.users.keys())
+
+                # Replace the cat args with the same args but only the first 5 nodes.
+                new_concat_args = (nodes_to_concat[:5],) + concat_args[1:]
+                node.args = new_concat_args
+
+                remainder_nodes_to_concat = nodes_to_concat[5:]
+                with gm.graph.inserting_after(node):
+                    remainder_concat_node = gm.graph.create_node(
+                        "call_function",
+                        target=exir_ops.edge.aten.cat.default,
+                        args=([],),  # we will replace this with the remainder nodes later
+                        kwargs=node.kwargs,
+                    )
+                node.replace_all_uses_with(remainder_concat_node)
+                if is_quantized:
+                    # Requantize the intermediate cat with the first input's
+                    # q-params so the decomposed chain stays in the quantized domain.
+                    q_params = nodes_to_concat[0].args[1:]
+                    with gm.graph.inserting_after(node):
+                        q_node = gm.graph.create_node(
+                            "call_function",
+                            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+                            args=(node,) + q_params,
+                            kwargs=node.kwargs,
+                        )
+                    with gm.graph.inserting_after(q_node):
+                        dq_node = gm.graph.create_node(
+                            "call_function",
+                            target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+                            args=(q_node,) + q_params,
+                            kwargs=node.kwargs,
+                        )
+                    remainder_concat_node.args = (
+                        [dq_node] + remainder_nodes_to_concat,
+                    ) + node.args[1:]
+                else:
+                    remainder_concat_node.args = (
+                        [node] + remainder_nodes_to_concat,
+                    ) + node.args[1:]
+
+        gm.recompile()
+        graph_module = super().call(gm).graph_module
+        return PassResult(graph_module, True)
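
The quantized branch above re-inserts a quantize/dequantize pair using the first input's q-params, so each intermediate cat result re-enters the quantized domain before feeding the next cat. A rough numeric sketch of why reusing the same params is lossless here (the scale and zero-point values are illustrative, not taken from the commit):

import torch

scale, zero_point = 0.05, 0  # illustrative q-params shared by all cat inputs
dq = [
    torch.dequantize(torch.quantize_per_tensor(torch.randn(1, 2, 3), scale, zero_point, torch.qint8))
    for _ in range(7)
]  # inputs as they would appear after the leading dequantize nodes

direct = torch.cat(dq, 1)
# Decomposed path: requantize the intermediate cat with the same params.
inter = torch.dequantize(torch.quantize_per_tensor(torch.cat(dq[:5], 1), scale, zero_point, torch.qint8))
decomposed = torch.cat([inter] + dq[5:], 1)
assert torch.equal(direct, decomposed)  # values already sit on the quantization grid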

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 2 additions & 2 deletions
@@ -181,10 +181,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
 
         num_tensors = len(node.all_input_nodes)
 
-        if not (num_tensors >= 2 and num_tensors <= 5):
+        if not (num_tensors >= 2):
             why(
                 node,
-                reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
+                reason=f"only support concatenation of >= 2 tensors, got {num_tensors} tensors",
             )
             return False
 
backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 0 additions & 4 deletions
@@ -165,10 +165,6 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
             if not isinstance(arg_val, torch.Tensor):
                 return False
 
-            # XNNPACK does not support empty tensors
-            if arg_val.numel() == 0:
-                return False
-
             if arg_val.dtype not in valid_dtypes:
                 return False

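For reference, eager-mode torch.cat already accepts tensors that are empty along the concatenation dimension; they simply contribute nothing to the output. This is the pattern the removed numel() check used to reject:

import torch

parts = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3))
out = torch.cat(parts, dim=0)
print(out.shape)  # torch.Size([4, 2, 3]) -- the empty tensor adds no rows
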
backends/xnnpack/test/ops/test_cat.py

Lines changed: 43 additions & 33 deletions
@@ -93,6 +93,29 @@ def test_fp16_cat4(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp16_cat5(self):
+        """
+        Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
+        """
+        inputs = (
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
+            torch.randn(2, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+        )
+        self._test_cat(self.Cat(), inputs)
+
+    def test_fp16_cat_gt_5(self):
+        """
+        Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
+        """
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3).to(torch.float16))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_fp32_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs)
@@ -120,6 +143,13 @@ def test_fp32_cat5(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp32_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_qs8_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +167,26 @@ def test_qs8_cat4(self):
         )
         self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
 
-    def test_fp32_cat_unsupported(self):
-        """
-        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
-        """
+    def test_qs8_cat5(self):
         inputs = (
             torch.randn(1, 2, 3),
             torch.randn(3, 2, 3),
             torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(2, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge_transform_and_lower()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
-        )
-
-    def test_fp32_cat_unsupported_legacy_mode(self):
-        """
-        XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
-        """
-        inputs = (
-            torch.randn(1, 2, 3),
-            torch.randn(3, 2, 3),
-            torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(6, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge()
-            .partition()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
         )
+        self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)
+
+    def test_qs8_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
+
+    def test_qs8_cat_with_empty_tensor(self):
+        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3))
+        self._test_cat(self.Cat(), inputs, cat_num=3, quant=True)
 
 
 class CatNegativeDim(torch.nn.Module):
     def __init__(self):
backends/xnnpack/test/passes/test_decompose_cat.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import unittest
+
+import torch
+from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
+from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+
+
+class TestDecomposeCatPass(unittest.TestCase):
+    PassStage = RunPasses([DecomposeConcatenate])
+    cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"
+
+    class Cat(torch.nn.Module):
+        def forward(self, *args):
+            xs = [*args]
+            x = torch.cat(xs)
+            return x + x  # Quantize by propagation.
+
+    def test_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+
+            # One cat for the first 5 inputs, then one more for every
+            # additional group of up to 4 inputs.
+            num_cats = int(len(inputs) > 5)
+            num_cats += math.ceil((len(inputs) - 5) / 4)
+            (
+                Tester(self.Cat(), tuple(inputs))
+                .export()
+                .to_edge()
+                .check_count({self.cat_name: 1})
+                .run_passes(self.PassStage)
+                .check_count({self.cat_name: num_cats})
+                .run_method_and_compare_outputs()
+            )
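
As a sanity check on the num_cats arithmetic in the test above, simulating the greedy chunking (the first cat consumes 5 inputs, each later cat consumes the previous result plus up to 4 more) reproduces the same counts. count_cats is a hypothetical helper for illustration:

import math

def count_cats(num_inputs, max_inputs=5):
    # One cat for the first `max_inputs` inputs, then one more cat per
    # group of (max_inputs - 1) remaining inputs.
    cats, remaining = 1, num_inputs - max_inputs
    while remaining > 0:
        cats += 1
        remaining -= max_inputs - 1
    return cats

for n in range(6, 10):
    assert count_cats(n) == int(n > 5) + math.ceil((n - 5) / 4)  # 2 cats for 6-9 inputs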
