diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 377bf62aa7d..dd551ea3fa7 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -187,6 +187,15 @@ def test_qs8_cat_gt_5(self): inputs.append(torch.randn(1, 2, 3)) self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True) + def test_qs8_cat_with_empty_tensor(self): + inputs = ( + torch.randn(0, 2, 3), + torch.randn(1, 2, 3), + torch.randn(3, 2, 3), + torch.randn(0, 2, 3), + ) + self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) + class CatNegativeDim(torch.nn.Module): def __init__(self): super().__init__() diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index ba300f70328..c34d4acede2 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -17,6 +17,7 @@ python_library( ":memory_planning_pass", ":normalize_transpose_pass", ":prim_ops_py_registry", + ":prune_empty_tensor_pass", ":quant_fusion_pass", ":quantize_io_pass", ":remove_noop_pass", @@ -197,6 +198,18 @@ python_library( ], ) +python_library( + name = "prune_empty_tensor_pass", + srcs = [ + "prune_empty_tensors_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "remove_mixed_type_operators", srcs = [ diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 6006f2463db..03bec011937 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -42,6 +42,7 @@ from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass +from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass @@ -486,6 +487,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult ScalarToTensorPass(), SymToTensorPass(), RemoveNoopPass(), + PruneEmptyTensorsPass(), RemoveToCopyPass(), ] ).passes diff --git a/exir/passes/prune_empty_tensors_pass.py b/exir/passes/prune_empty_tensors_pass.py new file mode 100644 index 00000000000..e9addfadced --- /dev/null +++ b/exir/passes/prune_empty_tensors_pass.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import cast, List + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule, Node + +# This is a list of ops that are No Ops if used with an empty tensor. +# Which means that if we remove the empty tensor as input to this op, +# The result of the operation will stay the same + + +class PruneEmptyTensorsPass(ExportPass): + """ + Removes Any empty tensors from the graph that can safely be removed + without affecting the results of the graph. Currently we remove empty + tensor operations from the following ops: + - aten.cat.default + """ + + def remove_empty_tensors_from_cat( + self, graph_module: GraphModule, cat_node: Node + ) -> None: + """ + Removes empty tensors from the graph that are inputs to aten.cat.default + """ + concat_list = cast(List[Node], cat_node.args[0]) + pruned_concat_list = [] + for input_arg in concat_list: + input_arg_tensor = input_arg.meta["val"] + if input_arg_tensor.numel() != 0: + pruned_concat_list.append(input_arg) + + cat_node.args = (pruned_concat_list,) + cat_node.args[1:] + if len(pruned_concat_list) == 0: + # if all the inputs to the cat are empty tensors, then we can replace + # this concat node with an aten full like + cat_tensor = cat_node.meta["val"] + with graph_module.graph.inserting_after(cat_node): + full_like = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.aten.full.default, + args=(tuple(cat_tensor.shape), 0), + kwargs={"dtype": cat_tensor.dtype}, + ) + full_like.meta = cat_node.meta + cat_node.replace_all_uses_with(full_like) + + def call(self, graph_module: GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target == torch.ops.aten.cat.default: + self.remove_empty_tensors_from_cat(graph_module, node) + + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + + return PassResult(graph_module, True) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 13253b0edcd..650a77c6ef6 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -443,6 +443,20 @@ python_unittest( ], ) +python_unittest( + name = "test_prune_empty_tensors", + srcs = [ + "test_prune_empty_tensors_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:memory", + "//executorch/exir/capture:config", + "//executorch/exir/passes:lib", + ], +) + python_unittest( name = "warnings", srcs = [ diff --git a/exir/tests/test_prune_empty_tensors_pass.py b/exir/tests/test_prune_empty_tensors_pass.py new file mode 100644 index 00000000000..8945c4544b2 --- /dev/null +++ b/exir/tests/test_prune_empty_tensors_pass.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn +from executorch.exir import to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes import MemoryPlanningPass + + +class TestCat(nn.Module): + def forward(self, x, y, z): + empty = torch.empty((0, 6)) + return torch.cat([empty, x, empty, y, z, empty]) + + def get_example_inputs(self): + return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6)) + + +class TestPruneEmptyTensors(unittest.TestCase): + def test_empty_tensor_removed_from_cat(self) -> None: + model = TestCat() + model.eval() + example_inputs = model.get_example_inputs() + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph_module.graph.nodes: + if node.target in [ + exir_ops.edge.aten.cat.default, + torch.ops.aten.cat.default, + ]: + self.assertTrue(len(node.all_input_nodes) == 3) + for input_arg in node.all_input_nodes: + tensor_val = input_arg.meta["val"] + self.assertTrue(tensor_val.numel() != 0) + + actual = etpm.exported_program().module()(*example_inputs) + + reference = model(*example_inputs) + + self.assertTrue(torch.allclose(actual, reference)) + + def test_cat_removed_all_empty(self) -> None: + model = TestCat() + model.eval() + example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6))) + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph_module.graph.nodes: + self.assertFalse( + node.target + in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default] + ) + + actual = etpm.exported_program().module()(*example_inputs) + + reference = model(*example_inputs) + + self.assertTrue(torch.allclose(actual, reference))