Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backends/xnnpack/test/ops/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,15 @@ def test_qs8_cat_gt_5(self):
inputs.append(torch.randn(1, 2, 3))
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)

def test_qs8_cat_with_empty_tensor(self):
inputs = (
torch.randn(0, 2, 3),
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(0, 2, 3),
)
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)

class CatNegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
13 changes: 13 additions & 0 deletions exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ python_library(
":memory_planning_pass",
":normalize_transpose_pass",
":prim_ops_py_registry",
":prune_empty_tensor_pass",
":quant_fusion_pass",
":quantize_io_pass",
":remove_noop_pass",
Expand Down Expand Up @@ -197,6 +198,18 @@ python_library(
],
)

python_library(
name = "prune_empty_tensor_pass",
srcs = [
"prune_empty_tensors_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

python_library(
name = "remove_mixed_type_operators",
srcs = [
Expand Down
2 changes: 2 additions & 0 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
Expand Down Expand Up @@ -486,6 +487,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
ScalarToTensorPass(),
SymToTensorPass(),
RemoveNoopPass(),
PruneEmptyTensorsPass(),
RemoveToCopyPass(),
]
).passes
Expand Down
67 changes: 67 additions & 0 deletions exir/passes/prune_empty_tensors_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import cast, List

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node

# This is a list of ops that are No Ops if used with an empty tensor.
# Which means that if we remove the empty tensor as input to this op,
# The result of the operation will stay the same


class PruneEmptyTensorsPass(ExportPass):
"""
Removes Any empty tensors from the graph that can safely be removed
without affecting the results of the graph. Currently we remove empty
tensor operations from the following ops:
- aten.cat.default
"""

def remove_empty_tensors_from_cat(
self, graph_module: GraphModule, cat_node: Node
) -> None:
"""
Removes empty tensors from the graph that are inputs to aten.cat.default
"""
concat_list = cast(List[Node], cat_node.args[0])
pruned_concat_list = []
for input_arg in concat_list:
input_arg_tensor = input_arg.meta["val"]
if input_arg_tensor.numel() != 0:
pruned_concat_list.append(input_arg)

cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
if len(pruned_concat_list) == 0:
# if all the inputs to the cat are empty tensors, then we can replace
# this concat node with an aten full like
cat_tensor = cat_node.meta["val"]
with graph_module.graph.inserting_after(cat_node):
full_like = graph_module.graph.create_node(
"call_function",
target=exir_ops.edge.aten.full.default,
args=(tuple(cat_tensor.shape), 0),
kwargs={"dtype": cat_tensor.dtype},
)
full_like.meta = cat_node.meta
cat_node.replace_all_uses_with(full_like)

def call(self, graph_module: GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if node.target == torch.ops.aten.cat.default:
self.remove_empty_tensors_from_cat(graph_module, node)

graph_module.graph.eliminate_dead_code()
graph_module.graph.lint()

return PassResult(graph_module, True)
14 changes: 14 additions & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,20 @@ python_unittest(
],
)

python_unittest(
name = "test_prune_empty_tensors",
srcs = [
"test_prune_empty_tensors_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:memory",
"//executorch/exir/capture:config",
"//executorch/exir/passes:lib",
],
)

python_unittest(
name = "warnings",
srcs = [
Expand Down
77 changes: 77 additions & 0 deletions exir/tests/test_prune_empty_tensors_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes import MemoryPlanningPass


class TestCat(nn.Module):
def forward(self, x, y, z):
empty = torch.empty((0, 6))
return torch.cat([empty, x, empty, y, z, empty])

def get_example_inputs(self):
return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6))


class TestPruneEmptyTensors(unittest.TestCase):
def test_empty_tensor_removed_from_cat(self) -> None:
model = TestCat()
model.eval()
example_inputs = model.get_example_inputs()
ep = torch.export.export(model, example_inputs, strict=True)
etpm = to_edge(ep).to_executorch(
config=ExecutorchBackendConfig(
remove_view_copy=False,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
),
)

for node in etpm.exported_program().graph_module.graph.nodes:
if node.target in [
exir_ops.edge.aten.cat.default,
torch.ops.aten.cat.default,
]:
self.assertTrue(len(node.all_input_nodes) == 3)
for input_arg in node.all_input_nodes:
tensor_val = input_arg.meta["val"]
self.assertTrue(tensor_val.numel() != 0)

actual = etpm.exported_program().module()(*example_inputs)

reference = model(*example_inputs)

self.assertTrue(torch.allclose(actual, reference))

def test_cat_removed_all_empty(self) -> None:
model = TestCat()
model.eval()
example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6)))
ep = torch.export.export(model, example_inputs, strict=True)
etpm = to_edge(ep).to_executorch(
config=ExecutorchBackendConfig(
remove_view_copy=False,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
),
)

for node in etpm.exported_program().graph_module.graph.nodes:
self.assertFalse(
node.target
in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default]
)

actual = etpm.exported_program().module()(*example_inputs)

reference = model(*example_inputs)

self.assertTrue(torch.allclose(actual, reference))