Skip to content

Commit c4c187a

Browse files
committed
fix lints
1 parent 0879b16 commit c4c187a

File tree

4 files changed

+23
-25
lines changed

4 files changed

+23
-25
lines changed

backends/xnnpack/test/ops/test_cat.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,14 @@ def test_qs8_cat_gt_5(self):
186186
for _ in range(num_inputs):
187187
inputs.append(torch.randn(1, 2, 3))
188188
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
189-
189+
190190
def test_qs8_cat_with_empty_tensor(self):
191-
inputs = (torch.randn(0, 2, 3), torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3))
191+
inputs = (
192+
torch.randn(0, 2, 3),
193+
torch.randn(1, 2, 3),
194+
torch.randn(3, 2, 3),
195+
torch.randn(0, 2, 3),
196+
)
192197
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
193198

194199
class CatNegativeDim(torch.nn.Module):

exir/passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
4343
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
4444
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
45+
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
4546
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
4647
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
47-
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
4848
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
4949
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
5050
ReplaceBrokenOpsWithFunctionalOpsPass,

exir/passes/prune_empty_tensors_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
# pyre-strict
88

9-
from typing import List, Tuple
10-
119
import torch
1210
from executorch.exir.dialects._ops import ops as exir_ops
1311
from executorch.exir.pass_base import ExportPass, PassResult
@@ -17,6 +15,7 @@
1715
# Which means that if we remove the empty tensor as input to this op,
1816
# The result of the operation will stay the same
1917

18+
2019
class PruneEmptyTensorsPass(ExportPass):
2120
"""
2221
Removes Any empty tensors from the graph that can safely be removed
@@ -25,7 +24,9 @@ class PruneEmptyTensorsPass(ExportPass):
2524
- aten.cat.default
2625
"""
2726

28-
def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Node) -> None:
27+
def remove_empty_tensors_from_cat(
28+
self, graph_module: GraphModule, cat_node: Node
29+
) -> None:
2930
"""
3031
Removes empty tensors from the graph that are inputs to aten.cat.default
3132
"""
@@ -35,7 +36,7 @@ def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Nod
3536
input_arg_tensor = input_arg.meta["val"]
3637
if input_arg_tensor.numel() != 0:
3738
pruned_concat_list.append(input_arg)
38-
39+
3940
cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
4041
if len(pruned_concat_list) == 0:
4142
# if all the inputs to the cat are empty tensors, then we can replace
@@ -50,13 +51,12 @@ def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Nod
5051
)
5152
full_like.meta = cat_node.meta
5253
cat_node.replace_all_uses_with(full_like)
53-
5454

5555
def call(self, graph_module: GraphModule) -> PassResult:
5656
for node in graph_module.graph.nodes:
5757
if node.op != "call_function":
5858
continue
59-
59+
6060
if node.target == torch.ops.aten.cat.default:
6161
self.remove_empty_tensors_from_cat(graph_module, node)
6262

exir/tests/test_prune_empty_tensors_pass.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import copy
87
import unittest
98

109
import torch
1110
import torch.nn as nn
1211
from executorch.exir import to_edge
1312
from executorch.exir.capture._config import ExecutorchBackendConfig
14-
from executorch.exir.passes import MemoryPlanningPass
1513
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.passes import MemoryPlanningPass
1615

1716

1817
class TestCat(nn.Module):
@@ -21,7 +20,7 @@ def forward(self, x, y, z):
2120
return torch.cat([empty, x, empty, y, z, empty])
2221

2322
def get_example_inputs(self):
24-
return (torch.rand(5, 6),torch.rand(5, 6),torch.rand(5, 6))
23+
return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6))
2524

2625

2726
class TestPruneEmptyTensors(unittest.TestCase):
@@ -46,10 +45,8 @@ def test_empty_tensor_removed_from_cat(self) -> None:
4645
for input_arg in node.all_input_nodes:
4746
tensor_val = input_arg.meta["val"]
4847
self.assertTrue(tensor_val.numel() != 0)
49-
50-
actual = etpm.exported_program().module()(
51-
*example_inputs
52-
)
48+
49+
actual = etpm.exported_program().module()(*example_inputs)
5350

5451
reference = model(*example_inputs)
5552

@@ -58,7 +55,7 @@ def test_empty_tensor_removed_from_cat(self) -> None:
5855
def test_cat_removed_all_empty(self) -> None:
5956
model = TestCat()
6057
model.eval()
61-
example_inputs = (torch.empty((0, 6)),torch.empty((0, 6)),torch.empty((0, 6)))
58+
example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6)))
6259
ep = torch.export.export(model, example_inputs, strict=True)
6360
etpm = to_edge(ep).to_executorch(
6461
config=ExecutorchBackendConfig(
@@ -69,15 +66,11 @@ def test_cat_removed_all_empty(self) -> None:
6966

7067
for node in etpm.exported_program().graph_module.graph.nodes:
7168
self.assertFalse(
72-
node.target in [
73-
exir_ops.edge.aten.cat.default,
74-
torch.ops.aten.cat.default
75-
]
69+
node.target
70+
in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default]
7671
)
77-
78-
actual = etpm.exported_program().module()(
79-
*example_inputs
80-
)
72+
73+
actual = etpm.exported_program().module()(*example_inputs)
8174

8275
reference = model(*example_inputs)
8376

0 commit comments

Comments
 (0)