Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backends/xnnpack/test/ops/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,15 @@ def test_qs8_cat_gt_5(self):
inputs.append(torch.randn(1, 2, 3))
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)

def test_qs8_cat_with_empty_tensor(self):
inputs = (
torch.randn(0, 2, 3),
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(0, 2, 3),
)
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)

class CatNegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
13 changes: 13 additions & 0 deletions exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ python_library(
":memory_planning_pass",
":normalize_transpose_pass",
":prim_ops_py_registry",
":prune_empty_tensor_pass",
":quant_fusion_pass",
":quantize_io_pass",
":remove_noop_pass",
Expand Down Expand Up @@ -197,6 +198,18 @@ python_library(
],
)

python_library(
name = "prune_empty_tensor_pass",
srcs = [
"prune_empty_tensors_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

python_library(
name = "remove_mixed_type_operators",
srcs = [
Expand Down
2 changes: 2 additions & 0 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
Expand Down Expand Up @@ -486,6 +487,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
ScalarToTensorPass(),
SymToTensorPass(),
RemoveNoopPass(),
PruneEmptyTensorsPass(),
RemoveToCopyPass(),
]
).passes
Expand Down
66 changes: 66 additions & 0 deletions exir/passes/prune_empty_tensors_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node

# This is a list of ops that are No Ops if used with an empty tensor.
# Which means that if we remove the empty tensor as input to this op,
# The result of the operation will stay the same


class PruneEmptyTensorsPass(ExportPass):
"""
Removes Any empty tensors from the graph that can safely be removed
without affecting the results of the graph. Currently we remove empty
tensor operations from the following ops:
- aten.cat.default
"""

def remove_empty_tensors_from_cat(
self, graph_module: GraphModule, cat_node: Node
) -> None:
"""
Removes empty tensors from the graph that are inputs to aten.cat.default
"""
concat_list = cat_node.all_input_nodes
pruned_concat_list = []
for input_arg in concat_list:
input_arg_tensor = input_arg.meta["val"]
if input_arg_tensor.numel() != 0:
pruned_concat_list.append(input_arg)

cat_node.args = (pruned_concat_list,) + cat_node.args[1:]
if len(pruned_concat_list) == 0:
# if all the inputs to the cat are empty tensors, then we can replace
# this concat node with an aten full like
cat_tensor = cat_node.meta["val"]
with graph_module.graph.inserting_after(cat_node):
full_like = graph_module.graph.create_node(
"call_function",
target=exir_ops.edge.aten.full.default,
args=(tuple(cat_tensor.shape), 0),
kwargs={"dtype": cat_tensor.dtype},
)
full_like.meta = cat_node.meta
cat_node.replace_all_uses_with(full_like)

def call(self, graph_module: GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if node.target == torch.ops.aten.cat.default:
self.remove_empty_tensors_from_cat(graph_module, node)

graph_module.graph.eliminate_dead_code()
graph_module.graph.lint()

return PassResult(graph_module, True)
14 changes: 14 additions & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,20 @@ python_unittest(
],
)

python_unittest(
name = "test_prune_empty_tensors",
srcs = [
"test_prune_empty_tensors_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:memory",
"//executorch/exir/capture:config",
"//executorch/exir/passes:lib",
],
)

python_unittest(
name = "warnings",
srcs = [
Expand Down
77 changes: 77 additions & 0 deletions exir/tests/test_prune_empty_tensors_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes import MemoryPlanningPass


class TestCat(nn.Module):
def forward(self, x, y, z):
empty = torch.empty((0, 6))
return torch.cat([empty, x, empty, y, z, empty])

def get_example_inputs(self):
return (torch.rand(5, 6), torch.rand(5, 6), torch.rand(5, 6))


class TestPruneEmptyTensors(unittest.TestCase):
def test_empty_tensor_removed_from_cat(self) -> None:
model = TestCat()
model.eval()
example_inputs = model.get_example_inputs()
ep = torch.export.export(model, example_inputs, strict=True)
etpm = to_edge(ep).to_executorch(
config=ExecutorchBackendConfig(
remove_view_copy=False,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
),
)

for node in etpm.exported_program().graph_module.graph.nodes:
if node.target in [
exir_ops.edge.aten.cat.default,
torch.ops.aten.cat.default,
]:
self.assertTrue(len(node.all_input_nodes) == 3)
for input_arg in node.all_input_nodes:
tensor_val = input_arg.meta["val"]
self.assertTrue(tensor_val.numel() != 0)

actual = etpm.exported_program().module()(*example_inputs)

reference = model(*example_inputs)

self.assertTrue(torch.allclose(actual, reference))

def test_cat_removed_all_empty(self) -> None:
model = TestCat()
model.eval()
example_inputs = (torch.empty((0, 6)), torch.empty((0, 6)), torch.empty((0, 6)))
ep = torch.export.export(model, example_inputs, strict=True)
etpm = to_edge(ep).to_executorch(
config=ExecutorchBackendConfig(
remove_view_copy=False,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
),
)

for node in etpm.exported_program().graph_module.graph.nodes:
self.assertFalse(
node.target
in [exir_ops.edge.aten.cat.default, torch.ops.aten.cat.default]
)

actual = etpm.exported_program().module()(*example_inputs)

reference = model(*example_inputs)

self.assertTrue(torch.allclose(actual, reference))
Loading