From 0879b164e1409db1b0293268bf4d84bceb7b9f3c Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 24 Jan 2025 14:18:06 -0800 Subject: [PATCH 1/3] Remove unused Empty Tensors from Edge Graph Summary: Following up on this post: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1683417975861822/ It seems like empty tensors aren't necessary in some ops, like concatenation. We should remove these as inputs in the edge graphs to make it easier to deal with. Reviewed By: digantdesai Differential Revision: D68589336 --- backends/xnnpack/test/ops/test_cat.py | 4 + exir/passes/TARGETS | 13 ++++ exir/passes/__init__.py | 2 + exir/passes/prune_empty_tensors_pass.py | 66 ++++++++++++++++ exir/tests/TARGETS | 14 ++++ exir/tests/test_prune_empty_tensors_pass.py | 84 +++++++++++++++++++++ 6 files changed, 183 insertions(+) create mode 100644 exir/passes/prune_empty_tensors_pass.py create mode 100644 exir/tests/test_prune_empty_tensors_pass.py diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 377bf62aa7d..78d972f5593 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -186,6 +186,10 @@ def test_qs8_cat_gt_5(self): for _ in range(num_inputs): inputs.append(torch.randn(1, 2, 3)) self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True) + + def test_qs8_cat_with_empty_tensor(self): + inputs = (torch.randn(0, 2, 3), torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3)) + self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) class CatNegativeDim(torch.nn.Module): def __init__(self): diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index ba300f70328..c34d4acede2 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -17,6 +17,7 @@ python_library( ":memory_planning_pass", ":normalize_transpose_pass", ":prim_ops_py_registry", + ":prune_empty_tensor_pass", ":quant_fusion_pass", ":quantize_io_pass", ":remove_noop_pass", @@ -197,6 +198,18 @@ python_library( ], ) +python_library( + name = "prune_empty_tensor_pass", + srcs = [ + "prune_empty_tensors_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "remove_mixed_type_operators", srcs = [ diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 6006f2463db..321ddd41fd4 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -44,6 +44,7 @@ from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass +from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import ( ReplaceBrokenOpsWithFunctionalOpsPass, @@ -486,6 +487,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult ScalarToTensorPass(), SymToTensorPass(), RemoveNoopPass(), + PruneEmptyTensorsPass(), RemoveToCopyPass(), ] ).passes diff --git a/exir/passes/prune_empty_tensors_pass.py b/exir/passes/prune_empty_tensors_pass.py new file mode 100644 index 00000000000..1c025f7472d --- /dev/null +++ b/exir/passes/prune_empty_tensors_pass.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import List, Tuple + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule, Node + +# This is a list of ops that are No Ops if used with an empty tensor. +# Which means that if we remove the empty tensor as input to this op, +# The result of the operation will stay the same + +class PruneEmptyTensorsPass(ExportPass): + """ + Removes Any empty tensors from the graph that can safely be removed + without affecting the results of the graph. Currently we remove empty + tensor operations from the following ops: + - aten.cat.default + """ + + def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Node) -> None: + """ + Removes empty tensors from the graph that are inputs to aten.cat.default + """ + concat_list = cat_node.all_input_nodes + pruned_concat_list = [] + for input_arg in concat_list: + input_arg_tensor = input_arg.meta["val"] + if input_arg_tensor.numel() != 0: + pruned_concat_list.append(input_arg) + + cat_node.args = (pruned_concat_list,) + cat_node.args[1:] + if len(pruned_concat_list) == 0: + # if all the inputs to the cat are empty tensors, then we can replace + # this concat node with an aten full like + cat_tensor = cat_node.meta["val"] + with graph_module.graph.inserting_after(cat_node): + full_like = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.aten.full.default, + args=(tuple(cat_tensor.shape), 0), + kwargs={"dtype": cat_tensor.dtype}, + ) + full_like.meta = cat_node.meta + cat_node.replace_all_uses_with(full_like) + + + def call(self, graph_module: GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target == torch.ops.aten.cat.default: + self.remove_empty_tensors_from_cat(graph_module, node) + + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + + return PassResult(graph_module, True) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 13253b0edcd..650a77c6ef6 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -443,6 +443,20 @@ python_unittest( ], ) +python_unittest( + name = "test_prune_empty_tensors", + srcs = [ + "test_prune_empty_tensors_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:memory", + "//executorch/exir/capture:config", + "//executorch/exir/passes:lib", + ], +) + python_unittest( name = "warnings", srcs = [ diff --git a/exir/tests/test_prune_empty_tensors_pass.py b/exir/tests/test_prune_empty_tensors_pass.py new file mode 100644 index 00000000000..9f5b3b0d52f --- /dev/null +++ b/exir/tests/test_prune_empty_tensors_pass.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +import torch.nn as nn +from executorch.exir import to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.dialects._ops import ops as exir_ops + + +class TestCat(nn.Module): + def forward(self, x, y, z): + empty = torch.empty((0, 6)) + return torch.cat([empty, x, empty, y, z, empty]) + + def get_example_inputs(self): + return (torch.rand(5, 6),torch.rand(5, 6),torch.rand(5, 6)) + + +class TestPruneEmptyTensors(unittest.TestCase): + def test_empty_tensor_removed_from_cat(self) -> None: + model = TestCat() + model.eval() + example_inputs = model.get_example_inputs() + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph_module.graph.nodes: + if node.target in [ + exir_ops.edge.aten.cat.default, + torch.ops.aten.cat.default, + ]: + self.assertTrue(len(node.all_input_nodes) == 3) + for input_arg in node.all_input_nodes: + tensor_val = input_arg.meta["val"] + self.assertTrue(tensor_val.numel() != 0) + + actual = etpm.exported_program().module()( + *example_inputs + ) + + reference = model(*example_inputs) + + self.assertTrue(torch.allclose(actual, reference)) + + def test_cat_removed_all_empty(self) -> None: + model = TestCat() + model.eval() + example_inputs = (torch.empty((0, 6)),torch.empty((0, 6)),torch.empty((0, 6))) + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph_module.graph.nodes: + self.assertFalse( + node.target in [ + exir_ops.edge.aten.cat.default, + torch.ops.aten.cat.default + ] + ) + + actual = etpm.exported_program().module()( + *example_inputs + ) + + reference = model(*example_inputs) + + self.assertTrue(torch.allclose(actual, reference)) From c4c187a5bc8e8a18a402bfc88383e82d89db0dd4 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 24 Jan 2025 14:33:51 -0800 Subject: [PATCH 2/3] fix lints --- backends/xnnpack/test/ops/test_cat.py | 9 ++++++-- exir/passes/__init__.py | 2 +- exir/passes/prune_empty_tensors_pass.py | 12 +++++----- exir/tests/test_prune_empty_tensors_pass.py | 25 ++++++++------------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 78d972f5593..dd551ea3fa7 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -186,9 +186,14 @@ def test_qs8_cat_gt_5(self): for _ in range(num_inputs): inputs.append(torch.randn(1, 2, 3)) self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True) - + def test_qs8_cat_with_empty_tensor(self): - inputs = (torch.randn(0, 2, 3), torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(0, 2, 3)) + inputs = ( + torch.randn(0, 2, 3), + torch.randn(1, 2, 3), + torch.randn(3, 2, 3), + torch.randn(0, 2, 3), + ) self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) class CatNegativeDim(torch.nn.Module): diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 321ddd41fd4..03bec011937 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -42,9 +42,9 @@ from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass +from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass -from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import ( ReplaceBrokenOpsWithFunctionalOpsPass, diff --git a/exir/passes/prune_empty_tensors_pass.py b/exir/passes/prune_empty_tensors_pass.py index 1c025f7472d..686856d9f90 100644 --- a/exir/passes/prune_empty_tensors_pass.py +++ b/exir/passes/prune_empty_tensors_pass.py @@ -6,8 +6,6 @@ # pyre-strict -from typing import List, Tuple - import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -17,6 +15,7 @@ # Which means that if we remove the empty tensor as input to this op, # The result of the operation will stay the same + class PruneEmptyTensorsPass(ExportPass): """ Removes Any empty tensors from the graph that can safely be removed @@ -25,7 +24,9 @@ class PruneEmptyTensorsPass(ExportPass): - aten.cat.default """ - def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Node) -> None: + def remove_empty_tensors_from_cat( + self, graph_module: GraphModule, cat_node: Node + ) -> None: """ Removes empty tensors from the graph that are inputs to aten.cat.default """ @@ -35,7 +36,7 @@ def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Nod input_arg_tensor = input_arg.meta["val"] if input_arg_tensor.numel() != 0: pruned_concat_list.append(input_arg) - + cat_node.args = (pruned_concat_list,) + cat_node.args[1:] if len(pruned_concat_list) == 0: # if all the inputs to the cat are empty tensors, then we can replace @@ -50,13 +51,12 @@ def remove_empty_tensors_from_cat(self, graph_module: GraphModule, cat_node: Nod ) full_like.meta = cat_node.meta cat_node.replace_all_uses_with(full_like) - def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: if node.op != "call_function": continue - + if node.target == torch.ops.aten.cat.default: self.remove_empty_tensors_from_cat(graph_module, node) diff --git a/exir/tests/test_prune_empty_tensors_pass.py b/exir/tests/test_prune_empty_tensors_pass.py index 9f5b3b0d52f..8945c4544b2 100644 --- a/exir/tests/test_prune_empty_tensors_pass.py +++ b/exir/tests/test_prune_empty_tensors_pass.py @@ -4,15 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import unittest import torch import torch.nn as nn from executorch.exir import to_edge from executorch.exir.capture._config import ExecutorchBackendConfig -from executorch.exir.passes import MemoryPlanningPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes import MemoryPlanningPass class TestCat(nn.Module): @@ -21,7 +20,7 @@ def forward(self, x, y, z): return torch.cat([empty, x, empty, y, z, empty]) def get_example_inputs(self): - return (torch.rand(5, 6),torch.rand(5, 6),torch.rand(5, 6)) + return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6)) class TestPruneEmptyTensors(unittest.TestCase): @@ -46,10 +45,8 @@ def test_empty_tensor_removed_from_cat(self) -> None: for input_arg in node.all_input_nodes: tensor_val = input_arg.meta["val"] self.assertTrue(tensor_val.numel() != 0) - - actual = etpm.exported_program().module()( - *example_inputs - ) + + actual = etpm.exported_program().module()(*example_inputs) reference = model(*example_inputs) @@ -58,7 +55,7 @@ def test_empty_tensor_removed_from_cat(self) -> None: def test_cat_removed_all_empty(self) -> None: model = TestCat() model.eval() - example_inputs = (torch.empty((0, 6)),torch.empty((0, 6)),torch.empty((0, 6))) + example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6))) ep = torch.export.export(model, example_inputs, strict=True) etpm = to_edge(ep).to_executorch( config=ExecutorchBackendConfig( @@ -69,15 +66,11 @@ def test_cat_removed_all_empty(self) -> None: for node in etpm.exported_program().graph_module.graph.nodes: self.assertFalse( - node.target in [ - exir_ops.edge.aten.cat.default, - torch.ops.aten.cat.default - ] + node.target + in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default] ) - - actual = etpm.exported_program().module()( - *example_inputs - ) + + actual = etpm.exported_program().module()(*example_inputs) reference = model(*example_inputs) From 9a6b75b44524a5c6b0fde8b9f7fcb9600dbf6f03 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 27 Jan 2025 11:42:09 -0800 Subject: [PATCH 3/3] fix issue with duplicate inputs --- exir/passes/prune_empty_tensors_pass.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exir/passes/prune_empty_tensors_pass.py b/exir/passes/prune_empty_tensors_pass.py index 686856d9f90..e9addfadced 100644 --- a/exir/passes/prune_empty_tensors_pass.py +++ b/exir/passes/prune_empty_tensors_pass.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +from typing import cast, List import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -30,7 +31,7 @@ def remove_empty_tensors_from_cat( """ Removes empty tensors from the graph that are inputs to aten.cat.default """ - concat_list = cat_node.all_input_nodes + concat_list = cast(List[Node], cat_node.args[0]) pruned_concat_list = [] for input_arg in concat_list: input_arg_tensor = input_arg.meta["val"]