Skip to content

Commit 009fed4

Browse files
committed
fix lints for decompose_cat
1 parent 24c593b commit 009fed4

File tree

3 files changed

+39
-32
lines changed

3 files changed

+39
-32
lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
1818
ConvertToUpsampleBilinear2d,
1919
)
20+
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
2021
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
2122
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
2223
FuseBatchNormWithConvPass,
2324
)
24-
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
2525
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
2626
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
2727
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (

backends/xnnpack/_passes/decompose_cat.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8+
89
import torch
10+
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
911
from executorch.exir.dialects._ops import ops as exir_ops
1012

1113
from executorch.exir.pass_base import ExportPass, PassResult
12-
from executorch.backends.xnnpack.utils.quant_utils import (
13-
is_dequant,
14-
is_quant,
15-
)
1614

1715
logger = logging.getLogger(__name__)
1816
logger.setLevel(logging.WARNING)
1917

18+
2019
class DecomposeConcatenate(ExportPass):
2120
"""
2221
XNNPACK's Concatenate operation only supports concatenation for <= 5 tensors
@@ -25,37 +24,40 @@ class DecomposeConcatenate(ExportPass):
2524
2625
Example:
2726
Before Pass:
28-
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
29-
27+
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
28+
3029
After Pass:
31-
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
32-
cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
30+
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
31+
cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
3332
"""
3433

3534
def call(self, graph_module: torch.fx.GraphModule):
3635
gm = graph_module
3736
for node in gm.graph.nodes:
38-
if (node.op == "call_function"
39-
and node.target.__name__ == "aten.cat.default"):
37+
if (
38+
node.op == "call_function"
39+
and node.target.__name__ == "aten.cat.default"
40+
):
4041
concat_args = node.args
4142
nodes_to_concat = node.args[0]
4243
if len(nodes_to_concat) <= 5:
4344
continue
44-
45-
is_quantized = (all(is_dequant(node) for node in nodes_to_concat)
46-
and all(is_quant(node) for node in node.users.keys()))
45+
46+
is_quantized = all(
47+
is_dequant(node) for node in nodes_to_concat
48+
) and all(is_quant(node) for node in node.users.keys())
4749

4850
# replace the cat args with the same args but only with the first 5 nodes
4951
new_concat_args = (nodes_to_concat[:5],) + concat_args[1:]
50-
node.args = new_concat_args
52+
node.args = new_concat_args
5153

5254
remainder_nodes_to_concat = nodes_to_concat[5:]
5355
with gm.graph.inserting_after(node):
5456
logger.debug(f"Decomposing cat node {node}")
5557
remainder_concat_node = gm.graph.create_node(
5658
"call_function",
5759
target=exir_ops.edge.aten.cat.default,
58-
args=([],), # we will replace this remainder_nodes later
60+
args=([],), # we will replace this remainder_nodes later
5961
kwargs=node.kwargs,
6062
)
6163
node.replace_all_uses_with(remainder_concat_node)
@@ -64,11 +66,13 @@ def call(self, graph_module: torch.fx.GraphModule):
6466
# concat node
6567
q_params = nodes_to_concat[0].args[1:]
6668
q_kwargs = nodes_to_concat[0].kwargs
67-
# Quantizer enforces all the inputs and output to a concat node must share
69+
# Quantizer enforces all the inputs and output to a concat node must share
6870
# the same qparams, this means the newly inserted q/dq pair must share the
6971
# same qparams as the first quantized input in the concat node.
7072
with gm.graph.inserting_after(node):
71-
logger.debug(f"Inserting Q/DQ pair for new cat node {remainder_concat_node}")
73+
logger.debug(
74+
f"Inserting Q/DQ pair for new cat node {remainder_concat_node}"
75+
)
7276
q_node = gm.graph.create_node(
7377
"call_function",
7478
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -82,10 +86,14 @@ def call(self, graph_module: torch.fx.GraphModule):
8286
args=(q_node,) + q_params,
8387
kwargs=q_kwargs,
8488
)
85-
remainder_concat_node.args = ([dq_node] + remainder_nodes_to_concat,) + node.args[1:]
89+
remainder_concat_node.args = (
90+
[dq_node] + remainder_nodes_to_concat,
91+
) + node.args[1:]
8692
else:
87-
remainder_concat_node.args = ([node] + remainder_nodes_to_concat,) + node.args[1:]
88-
93+
remainder_concat_node.args = (
94+
[node] + remainder_nodes_to_concat,
95+
) + node.args[1:]
96+
8997
gm.recompile()
9098
new_gm = super().call(gm).graph_module
9199
return PassResult(new_gm, True)

backends/xnnpack/test/passes/test_decompose_cat_pass.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import math
78
import unittest
89

910
import torch
10-
import math
1111
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
1212
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
1313

1414

1515
class TestDecomposeCatPass(unittest.TestCase):
1616
PassStage = RunPasses([DecomposeConcatenate])
1717
cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"
18+
1819
class Cat(torch.nn.Module):
1920
def forward(self, *args):
2021
xs = [*args]
@@ -29,9 +30,9 @@ def test_cat_gt_5(self):
2930
inputs = []
3031
for _ in range(num_inputs):
3132
inputs.append(torch.randn(1, 2, 3))
32-
33+
3334
num_cats = int(len(inputs) > 5)
34-
num_cats += math.ceil((len(inputs) - 5)/4)
35+
num_cats += math.ceil((len(inputs) - 5) / 4)
3536
(
3637
Tester(self.Cat(), tuple(inputs))
3738
.export()
@@ -42,7 +43,6 @@ def test_cat_gt_5(self):
4243
.run_method_and_compare_outputs()
4344
)
4445

45-
4646
def test_cat_gt_10(self):
4747
inputs = [
4848
torch.randn(1, 2, 3),
@@ -51,9 +51,9 @@ def test_cat_gt_10(self):
5151
inputs = []
5252
for _ in range(num_inputs):
5353
inputs.append(torch.randn(1, 2, 3))
54-
54+
5555
num_cats = int(len(inputs) > 5)
56-
num_cats += math.ceil((len(inputs) - 5)/4)
56+
num_cats += math.ceil((len(inputs) - 5) / 4)
5757
(
5858
Tester(self.Cat(), tuple(inputs))
5959
.export()
@@ -72,9 +72,9 @@ def test_qs8_cat_gt_5(self):
7272
inputs = []
7373
for _ in range(num_inputs):
7474
inputs.append(torch.randn(1, 2, 3))
75-
75+
7676
num_cats = int(len(inputs) > 5)
77-
num_cats += math.ceil((len(inputs) - 5)/4)
77+
num_cats += math.ceil((len(inputs) - 5) / 4)
7878
(
7979
Tester(self.Cat(), tuple(inputs))
8080
.quantize()
@@ -86,7 +86,6 @@ def test_qs8_cat_gt_5(self):
8686
.run_method_and_compare_outputs()
8787
)
8888

89-
9089
def test_cat_gt_10(self):
9190
inputs = [
9291
torch.randn(1, 2, 3),
@@ -95,9 +94,9 @@ def test_cat_gt_10(self):
9594
inputs = []
9695
for _ in range(num_inputs):
9796
inputs.append(torch.randn(1, 2, 3))
98-
97+
9998
num_cats = int(len(inputs) > 5)
100-
num_cats += math.ceil((len(inputs) - 5)/4)
99+
num_cats += math.ceil((len(inputs) - 5) / 4)
101100
(
102101
Tester(self.Cat(), tuple(inputs))
103102
.export()

0 commit comments

Comments
 (0)