Skip to content

Commit 34aff36

Browse files
committed
fix lints
1 parent 0879b16 commit 34aff36

File tree

4 files changed

+23
-22
lines changed

4 files changed

+23
-22
lines changed

backends/xnnpack/test/ops/test_cat.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,14 @@ def test_qs8_cat_gt_5(self):
186186
for _ in range(num_inputs):
187187
inputs.append(torch.randn(1, 2, 3))
188188
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
189-
189+
190190
def test_qs8_cat_with_empty_tensor(self):
191-
inputs = (torch.randn(0, 2, 3), torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3))
191+
inputs = (
192+
torch.randn(0, 2, 3),
193+
torch.randn(1, 2, 3),
194+
torch.randn(3, 2, 3),
195+
torch.randn(0, 2, 3),
196+
)
192197
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
193198

194199
class CatNegativeDim(torch.nn.Module):

exir/passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
4343
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
4444
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
45+
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
4546
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
4647
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
47-
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
4848
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
4949
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
5050
ReplaceBrokenOpsWithFunctionalOpsPass,

exir/passes/prune_empty_tensors_pass.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# Which means that if we remove the empty tensor as input to this op,
1818
# The result of the operation will stay the same
1919

20+
2021
class PruneEmptyTensorsPass(ExportPass):
2122
"""
2223
Removes Any empty tensors from the graph that can safely be removed
@@ -25,7 +26,9 @@ class PruneEmptyTensorsPass(ExportPass):
2526
- aten.cat.default
2627
"""
2728

28-
def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Node) -> None:
29+
def remove_empty_tensors_from_cat(
30+
self, graph_module: GraphModule, cat_node: Node
31+
) -> None:
2932
"""
3033
Removes empty tensors from the graph that are inputs to aten.cat.default
3134
"""
@@ -35,7 +38,7 @@ def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Nod
3538
input_arg_tensor = input_arg.meta["val"]
3639
if input_arg_tensor.numel() != 0:
3740
pruned_concat_list.append(input_arg)
38-
41+
3942
cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
4043
if len(pruned_concat_list) == 0:
4144
# if all the inputs to the cat are empty tensors, then we can replace
@@ -50,13 +53,12 @@ def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Nod
5053
)
5154
full_like.meta = cat_node.meta
5255
cat_node.replace_all_uses_with(full_like)
53-
5456

5557
def call(self, graph_module: GraphModule) -> PassResult:
5658
for node in graph_module.graph.nodes:
5759
if node.op != "call_function":
5860
continue
59-
61+
6062
if node.target == torch.ops.aten.cat.default:
6163
self.remove_empty_tensors_from_cat(graph_module, node)
6264

exir/tests/test_prune_empty_tensors_pass.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import torch.nn as nn
1212
from executorch.exir import to_edge
1313
from executorch.exir.capture._config import ExecutorchBackendConfig
14-
from executorch.exir.passes import MemoryPlanningPass
1514
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.passes import MemoryPlanningPass
1616

1717

1818
class TestCat(nn.Module):
@@ -21,7 +21,7 @@ def forward(self, x, y, z):
2121
return torch.cat([empty, x, empty, y, z, empty])
2222

2323
def get_example_inputs(self):
24-
return (torch.rand(5, 6),torch.rand(5, 6),torch.rand(5, 6))
24+
return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6))
2525

2626

2727
class TestPruneEmptyTensors(unittest.TestCase):
@@ -46,10 +46,8 @@ def test_empty_tensor_removed_from_cat(self) -> None:
4646
for input_arg in node.all_input_nodes:
4747
tensor_val = input_arg.meta["val"]
4848
self.assertTrue(tensor_val.numel() != 0)
49-
50-
actual = etpm.exported_program().module()(
51-
*example_inputs
52-
)
49+
50+
actual = etpm.exported_program().module()(*example_inputs)
5351

5452
reference = model(*example_inputs)
5553

@@ -58,7 +56,7 @@ def test_empty_tensor_removed_from_cat(self) -> None:
5856
def test_cat_removed_all_empty(self) -> None:
5957
model = TestCat()
6058
model.eval()
61-
example_inputs = (torch.empty((0, 6)),torch.empty((0, 6)),torch.empty((0, 6)))
59+
example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6)))
6260
ep = torch.export.export(model, example_inputs, strict=True)
6361
etpm = to_edge(ep).to_executorch(
6462
config=ExecutorchBackendConfig(
@@ -69,15 +67,11 @@ def test_cat_removed_all_empty(self) -> None:
6967

7068
for node in etpm.exported_program().graph_module.graph.nodes:
7169
self.assertFalse(
72-
node.target in [
73-
exir_ops.edge.aten.cat.default,
74-
torch.ops.aten.cat.default
75-
]
70+
node.target
71+
in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default]
7672
)
77-
78-
actual = etpm.exported_program().module()(
79-
*example_inputs
80-
)
73+
74+
actual = etpm.exported_program().module()(*example_inputs)
8175

8276
reference = model(*example_inputs)
8377

0 commit comments

Comments
 (0)