Commit 0879b16

mcr229 authored and facebook-github-bot committed
Remove unused Empty Tensors from Edge Graph
Remove unused Empty Tensors from Edge Graph

Summary: Following up on this post: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1683417975861822/

It seems that empty tensors aren't necessary for some ops, such as concatenation, so we remove them as inputs in the edge graphs to make the graphs easier to deal with.

Reviewed By: digantdesai

Differential Revision: D68589336
1 parent b522084 commit 0879b16
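For background (not part of the commit): concatenating a zero-element tensor contributes nothing to the output, which is why such inputs can be pruned from a cat without changing its result. A minimal sketch in plain PyTorch:

import torch

x = torch.randn(1, 2, 3)
y = torch.randn(3, 2, 3)
empty = torch.randn(0, 2, 3)  # zero elements along the concat dimension

# Including the empty tensor does not change the concatenation result.
assert torch.equal(torch.cat([empty, x, empty, y]), torch.cat([x, y]))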

File tree

6 files changed: +183 -0 lines changed

backends/xnnpack/test/ops/test_cat.py
exir/passes/TARGETS
exir/passes/__init__.py
exir/passes/prune_empty_tensors_pass.py
exir/tests/TARGETS
exir/tests/test_prune_empty_tensors_pass.py

backends/xnnpack/test/ops/test_cat.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,10 @@ def test_qs8_cat_gt_5(self):
         for _ in range(num_inputs):
             inputs.append(torch.randn(1, 2, 3))
         self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
+
+    def test_qs8_cat_with_empty_tensor(self):
+        inputs = (torch.randn(0, 2, 3), torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3))
+        self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
 
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):

exir/passes/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -17,6 +17,7 @@ python_library(
         ":memory_planning_pass",
         ":normalize_transpose_pass",
         ":prim_ops_py_registry",
+        ":prune_empty_tensor_pass",
         ":quant_fusion_pass",
         ":quantize_io_pass",
         ":remove_noop_pass",
@@ -197,6 +198,18 @@ python_library(
     ],
 )
 
+python_library(
+    name = "prune_empty_tensor_pass",
+    srcs = [
+        "prune_empty_tensors_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 python_library(
     name = "remove_mixed_type_operators",
     srcs = [

exir/passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
 from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
+from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
 from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
 from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
     ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -486,6 +487,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
         ScalarToTensorPass(),
         SymToTensorPass(),
         RemoveNoopPass(),
+        PruneEmptyTensorsPass(),
         RemoveToCopyPass(),
     ]
 ).passes
exir/passes/prune_empty_tensors_pass.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import List, Tuple
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+# The ops handled below are no-ops when given an empty tensor as input:
+# if we remove the empty tensor from the op's inputs,
+# the result of the operation stays the same.
+
+class PruneEmptyTensorsPass(ExportPass):
+    """
+    Removes any empty tensors from the graph that can safely be removed
+    without affecting the results of the graph. Currently we remove empty
+    tensor inputs from the following ops:
+        - aten.cat.default
+    """
+
+    def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Node) -> None:
+        """
+        Removes empty tensors from the graph that are inputs to aten.cat.default
+        """
+        concat_list = cat_node.all_input_nodes
+        pruned_concat_list = []
+        for input_arg in concat_list:
+            input_arg_tensor = input_arg.meta["val"]
+            if input_arg_tensor.numel() != 0:
+                pruned_concat_list.append(input_arg)
+
+        cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
+        if len(pruned_concat_list) == 0:
+            # If all the inputs to the cat are empty tensors, then we can replace
+            # this concat node with an aten full of zeros with the same shape.
+            cat_tensor = cat_node.meta["val"]
+            with graph_module.graph.inserting_after(cat_node):
+                full_like = graph_module.graph.create_node(
+                    "call_function",
+                    target=exir_ops.edge.aten.full.default,
+                    args=(tuple(cat_tensor.shape), 0),
+                    kwargs={"dtype": cat_tensor.dtype},
+                )
+            full_like.meta = cat_node.meta
+            cat_node.replace_all_uses_with(full_like)
+
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target == torch.ops.aten.cat.default:
+                self.remove_empty_tensors_from_cat(graph_module, node)
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.graph.lint()
+
+        return PassResult(graph_module, True)
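A minimal usage sketch (not part of the commit; it assumes the pass can be invoked directly through the standard torch.fx PassBase call interface, which returns a PassResult, and the module M is hypothetical, for illustration only):

import torch

from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass


class M(torch.nn.Module):
    def forward(self, x):
        # The zero-row tensor is a no-op input to cat and should be pruned.
        return torch.cat([torch.empty(0, 4), x])


ep = torch.export.export(M(), (torch.rand(2, 4),), strict=True)
result = PruneEmptyTensorsPass()(ep.graph_module)
print(result.graph_module.graph)  # the cat should now see only the non-empty input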

exir/tests/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -443,6 +443,20 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_prune_empty_tensors",
+    srcs = [
+        "test_prune_empty_tensors_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+    ],
+)
+
 python_unittest(
     name = "warnings",
     srcs = [
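Assuming the standard Buck workflow used for the other targets in this file, the new unit test can presumably be run via its target name, e.g. buck2 test //executorch/exir/tests:test_prune_empty_tensors (the full target path is inferred from this TARGETS file's location and is not stated in the commit).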
exir/tests/test_prune_empty_tensors_pass.py

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+import torch.nn as nn
+from executorch.exir import to_edge
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes import MemoryPlanningPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class TestCat(nn.Module):
+    def forward(self, x, y, z):
+        empty = torch.empty((0, 6))
+        return torch.cat([empty, x, empty, y, z, empty])
+
+    def get_example_inputs(self):
+        return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6))
+
+
+class TestPruneEmptyTensors(unittest.TestCase):
+    def test_empty_tensor_removed_from_cat(self) -> None:
+        model = TestCat()
+        model.eval()
+        example_inputs = model.get_example_inputs()
+        ep = torch.export.export(model, example_inputs, strict=True)
+        etpm = to_edge(ep).to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=False,
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+            ),
+        )
+
+        for node in etpm.exported_program().graph_module.graph.nodes:
+            if node.target in [
+                exir_ops.edge.aten.cat.default,
+                torch.ops.aten.cat.default,
+            ]:
+                self.assertTrue(len(node.all_input_nodes) == 3)
+                for input_arg in node.all_input_nodes:
+                    tensor_val = input_arg.meta["val"]
+                    self.assertTrue(tensor_val.numel() != 0)
+
+        actual = etpm.exported_program().module()(
+            *example_inputs
+        )
+
+        reference = model(*example_inputs)
+
+        self.assertTrue(torch.allclose(actual, reference))
+
+    def test_cat_removed_all_empty(self) -> None:
+        model = TestCat()
+        model.eval()
+        example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6)))
+        ep = torch.export.export(model, example_inputs, strict=True)
+        etpm = to_edge(ep).to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=False,
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+            ),
+        )
+
+        for node in etpm.exported_program().graph_module.graph.nodes:
+            self.assertFalse(
+                node.target in [
+                    exir_ops.edge.aten.cat.default,
+                    torch.ops.aten.cat.default
+                ]
+            )
+
+        actual = etpm.exported_program().module()(
+            *example_inputs
+        )
+
+        reference = model(*example_inputs)
+
+        self.assertTrue(torch.allclose(actual, reference))
