Commit 24c593b

mcr229 authored and facebook-github-bot committed

> 5 Cat Inputs and Support Empty Input Tensors (#7855)
Summary: PyTorch's cat.default operator can take an arbitrarily large number of inputs, because its input is a Tensor List. XNNPACK, however, supports at most 5 input tensors at a time. Concatenations of more than 5 tensors are common, so we should still support them. We do so by adding a pass that decomposes the cat operator: the first 5 tensors are concatenated together, and then we recursively insert additional cat nodes that concatenate the result of the previous cat with the next 4 input tensors. Another common design pattern is for a concatenation to start from an empty tensor and concatenate tensors into it, which produces empty tensors as inputs to cat. Previously we did not partition cats with empty tensor inputs. I don't remember what the original issue with empty tensors was, but it appears to work now, so that partitioner check is disabled for now. CI should surface an error if this is in fact erroneous.

Reviewed By: digantdesai

Differential Revision: D68523312
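For illustration, here is a minimal standalone sketch of the decomposition scheme the summary describes (the `chained_cat` helper is hypothetical, not part of this commit): the first cat takes 5 inputs, and each subsequent cat takes the previous result plus up to 4 more inputs.

```python
import torch


def chained_cat(tensors, dim=0, max_inputs=5):
    """Concatenate arbitrarily many tensors using cats of <= max_inputs each.

    Hypothetical helper mirroring the pass's decomposition, not commit code.
    """
    # First cat consumes up to max_inputs (5) input tensors.
    result = torch.cat(tensors[:max_inputs], dim=dim)
    remaining = tensors[max_inputs:]
    while remaining:
        # Each follow-up cat consumes the previous result plus up to
        # max_inputs - 1 new tensors (4 when max_inputs is 5).
        chunk, remaining = remaining[: max_inputs - 1], remaining[max_inputs - 1 :]
        result = torch.cat([result] + list(chunk), dim=dim)
    return result


# For N > 5 inputs this uses 1 + ceil((N - 5) / 4) cat ops, matching the
# counts the new pass tests check below.
ts = [torch.randn(1, 2, 3) for _ in range(9)]
assert torch.equal(chained_cat(ts), torch.cat(ts))
```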
1 parent 5cbfcdc commit 24c593b

6 files changed: +249 −51 lines changed

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 14 deletions
```diff
@@ -4,20 +4,7 @@ oncall("executorch")
 
 python_library(
     name = "xnnpack_passes",
-    srcs = [
-        "__init__.py",
-        "channels_last_tagged_reshape_pass.py",
-        "conv1d_unsqueeze_pass.py",
-        "convert_to_linear.py",
-        "convert_to_sdpa.py",
-        "convert_to_upsample_bilinear2d.py",
-        "fuse_activation_pass.py",
-        "fuse_batch_norm_with_conv.py",
-        "prelu_reshape_pass.py",
-        "remove_getitem_op.py",
-        "tag_implicit_q_dq_pass.py",
-        "xnnpack_pass.py",
-    ],
+    srcs = native.glob(["*.py"]),
     deps = [
         "//caffe2:torch",
        "//executorch/backends/transforms:addmm_mm_to_linear",
```

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@
 from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
     FuseBatchNormWithConvPass,
 )
+from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
@@ -63,6 +64,7 @@ def __init__(
             ConstPropPass,
             FuseBatchNormWithConvPass,
             FuseActivationPass,
+            DecomposeConcatenate,
             RemoveGetItemPass,
             Conv1dUnsqueezePass,
             PReLUReshapePass,
```

backends/xnnpack/_passes/decompose_cat.py

Lines changed: 91 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeConcatenate(ExportPass):
    """
    XNNPACK's Concatenate operation only supports concatenation of <= 5 tensors
    at a time. As a result, to support cats with > 5 tensors, we decompose them
    into sequences of cats, each with <= 5 tensors.

    Example:
    Before pass:
        cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);

    After pass:
        cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
        cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
    """

    def call(self, graph_module: torch.fx.GraphModule):
        gm = graph_module
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target.__name__ == "aten.cat.default"
            ):
                concat_args = node.args
                nodes_to_concat = node.args[0]
                if len(nodes_to_concat) <= 5:
                    continue

                is_quantized = all(
                    is_dequant(node) for node in nodes_to_concat
                ) and all(is_quant(node) for node in node.users.keys())

                # Replace the cat args with the same args but only the first 5 nodes.
                new_concat_args = (nodes_to_concat[:5],) + concat_args[1:]
                node.args = new_concat_args

                remainder_nodes_to_concat = nodes_to_concat[5:]
                with gm.graph.inserting_after(node):
                    logger.debug(f"Decomposing cat node {node}")
                    remainder_concat_node = gm.graph.create_node(
                        "call_function",
                        target=exir_ops.edge.aten.cat.default,
                        args=([],),  # we fill in the remainder nodes below
                        kwargs=node.kwargs,
                    )
                    node.replace_all_uses_with(remainder_concat_node)
                if is_quantized:
                    # If quantized, we need to enforce the q/dq pattern for the
                    # newly inserted concat node. The quantizer enforces that all
                    # inputs and the output of a concat node share the same
                    # qparams, so the newly inserted q/dq pair must share the
                    # qparams of the first quantized input of the concat node.
                    q_params = nodes_to_concat[0].args[1:]
                    q_kwargs = nodes_to_concat[0].kwargs
                    with gm.graph.inserting_after(node):
                        logger.debug(
                            f"Inserting Q/DQ pair for new cat node {remainder_concat_node}"
                        )
                        q_node = gm.graph.create_node(
                            "call_function",
                            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                            args=(node,) + q_params,
                            kwargs=q_kwargs,
                        )
                    with gm.graph.inserting_after(q_node):
                        dq_node = gm.graph.create_node(
                            "call_function",
                            target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                            args=(q_node,) + q_params,
                            kwargs=q_kwargs,
                        )
                    remainder_concat_node.args = (
                        [dq_node] + remainder_nodes_to_concat,
                    ) + node.args[1:]
                else:
                    remainder_concat_node.args = (
                        [node] + remainder_nodes_to_concat,
                    ) + node.args[1:]

        gm.recompile()
        new_gm = super().call(gm).graph_module
        return PassResult(new_gm, True)
```
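For reference, a hedged sketch of exercising this pass by hand on an edge-dialect graph. The `torch.export`/`to_edge` call pattern and the node-counting loop are illustrative assumptions, not code from this commit:

```python
import torch
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.exir import to_edge


class Cat(torch.nn.Module):
    def forward(self, *args):
        return torch.cat(list(args))


# 7 inputs -> one cat of the first 5, then one cat of [result] + 2 remaining.
inputs = tuple(torch.randn(1, 2, 3) for _ in range(7))
edge = to_edge(torch.export.export(Cat(), inputs))
gm = edge.exported_program().graph_module

gm = DecomposeConcatenate()(gm).graph_module

num_cats = sum(
    1
    for n in gm.graph.nodes
    if n.op == "call_function" and "cat" in str(n.target)
)
print(num_cats)  # expected: 2
```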

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -181,10 +181,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
 
         num_tensors = len(node.all_input_nodes)
 
-        if not (num_tensors >= 2 and num_tensors <= 5):
+        if not (num_tensors >= 2):
             why(
                 node,
-                reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
+                reason=f"only support concatenation of >= 2 tensors, got {num_tensors} tensors",
             )
             return False
```

backends/xnnpack/test/ops/test_cat.py

Lines changed: 44 additions & 35 deletions
```diff
@@ -14,9 +14,13 @@
 
 class TestCat(unittest.TestCase):
     class Cat(torch.nn.Module):
+        def __init__(self, dim=0):
+            super().__init__()
+            self.dim = dim
+
         def forward(self, *args):
             xs = [*args]
-            x = torch.cat(xs)
+            x = torch.cat(xs, dim=self.dim)
             return x + x  # Quantize by propagation.
 
     def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
             tester.quantize()
 
         tester.export().check_count({"torch.ops.aten.cat": 1})
-        tester.dump_artifact()
 
         if quant:
             # Expect multiple quantize ops - one per input, cat, and add.
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp16_cat5(self):
+        """
+        Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
+        """
+        inputs = (
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
+            torch.randn(2, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+        )
+        self._test_cat(self.Cat(), inputs)
+
+    def test_fp16_cat_gt_5(self):
+        """
+        Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
+        """
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3).to(torch.float16))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_fp32_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs)
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp32_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_qs8_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
         )
         self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
 
-    def test_fp32_cat_unsupported(self):
-        """
-        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
-        """
+    def test_qs8_cat5(self):
         inputs = (
             torch.randn(1, 2, 3),
             torch.randn(3, 2, 3),
             torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(2, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge_transform_and_lower()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
-        )
-
-    def test_fp32_cat_unsupported_legacy_mode(self):
-        """
-        XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
-        """
-        inputs = (
-            torch.randn(1, 2, 3),
-            torch.randn(3, 2, 3),
-            torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(6, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge()
-            .partition()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
         )
+        self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)
+
+    def test_qs8_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
 
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
```

backends/xnnpack/test/passes/test_decompose_cat.py

Lines changed: 109 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest

import torch
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestDecomposeCatPass(unittest.TestCase):
    PassStage = RunPasses([DecomposeConcatenate])
    cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"

    class Cat(torch.nn.Module):
        def forward(self, *args):
            xs = [*args]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    def test_cat_gt_5(self):
        for num_inputs in range(6, 10):
            inputs = []
            for _ in range(num_inputs):
                inputs.append(torch.randn(1, 2, 3))

            num_cats = int(len(inputs) > 5)
            num_cats += math.ceil((len(inputs) - 5) / 4)
            (
                Tester(self.Cat(), tuple(inputs))
                .export()
                .to_edge()
                .check_count({self.cat_name: 1})
                .run_passes(self.PassStage)
                .check_count({self.cat_name: num_cats})
                .run_method_and_compare_outputs()
            )

    def test_cat_gt_10(self):
        for num_inputs in [11, 16, 18]:
            inputs = []
            for _ in range(num_inputs):
                inputs.append(torch.randn(1, 2, 3))

            num_cats = int(len(inputs) > 5)
            num_cats += math.ceil((len(inputs) - 5) / 4)
            (
                Tester(self.Cat(), tuple(inputs))
                .export()
                .to_edge()
                .check_count({self.cat_name: 1})
                .run_passes(self.PassStage)
                .check_count({self.cat_name: num_cats})
                .run_method_and_compare_outputs()
            )

    def test_qs8_cat_gt_5(self):
        for num_inputs in range(6, 10):
            inputs = []
            for _ in range(num_inputs):
                inputs.append(torch.randn(1, 2, 3))

            num_cats = int(len(inputs) > 5)
            num_cats += math.ceil((len(inputs) - 5) / 4)
            (
                Tester(self.Cat(), tuple(inputs))
                .quantize()
                .export()
                .to_edge()
                .check_count({self.cat_name: 1})
                .run_passes(self.PassStage)
                .check_count({self.cat_name: num_cats})
                .run_method_and_compare_outputs()
            )

    def test_qs8_cat_gt_10(self):
        for num_inputs in [11, 16, 18]:
            inputs = []
            for _ in range(num_inputs):
                inputs.append(torch.randn(1, 2, 3))

            num_cats = int(len(inputs) > 5)
            num_cats += math.ceil((len(inputs) - 5) / 4)
            (
                Tester(self.Cat(), tuple(inputs))
                .quantize()
                .export()
                .to_edge()
                .check_count({self.cat_name: 1})
                .run_passes(self.PassStage)
                .check_count({self.cat_name: num_cats})
                .run_method_and_compare_outputs()
            )
```
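The `num_cats` expression in these tests encodes the decomposition arithmetic: the first cat consumes 5 inputs, and every additional cat consumes at most 4 more, because one of its 5 slots holds the previous result. A quick sanity check of that count (illustrative helper, not part of the commit):

```python
import math


def expected_cats(n: int) -> int:
    # One cat for the first 5 inputs, then one more cat per group of up to
    # 4 remaining inputs: 1 + ceil((n - 5) / 4) for n > 5.
    return 1 if n <= 5 else 1 + math.ceil((n - 5) / 4)


# 6..9 inputs -> 2 cats; 10 inputs -> 3 cats; 18 inputs -> 1 + ceil(13/4) = 5.
assert [expected_cats(n) for n in (5, 6, 9, 10, 18)] == [1, 2, 2, 3, 5]
```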
