Skip to content

Commit 8fe514f

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Add additional default decompositions for upsample operators
Summary: There are several core ATen ops that are not yet supported on ExecuTorch, including upsample_bilinear2d.vec and upsample_nearest2d.vec. These ops are currently not decomposed by default with PyTorch export default decompositions, but should be. Existing ET consumers rely on this behavior, so we need to preserve it until we have upsample kernels ready. This change allows ET to opt-into decomposing these ops, regardless of the PyTorch default export decomposition table. This will unblock updating PyTorch with the correct behavior (see pytorch/pytorch#116684). Once the upsample kernels land in ET, we can remove these decompositions. This is currently blocked by pin bumps, which may take a while to resolve. Differential Revision: D67443180
1 parent 2ed5ce3 commit 8fe514f

File tree

2 files changed

+56
-15
lines changed

2 files changed

+56
-15
lines changed

exir/program/test/test_program.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import copy
1010
import unittest
11+
from collections.abc import Iterable
1112
from typing import Any, Dict
1213

1314
import torch
@@ -21,6 +22,9 @@
2122
from executorch.exir.lowered_backend_module import get_lowered_submodules
2223
from executorch.exir.pass_base import ExportPass
2324
from executorch.exir.passes import MemoryPlanningPass
25+
from executorch.exir.passes.replace_aten_with_edge_pass import (
26+
aten_to_edge,
27+
)
2428
from executorch.exir.program._program import (
2529
EdgeProgramManager,
2630
ExecutorchProgramManager,
@@ -41,6 +45,15 @@
4145
from torch.nn import functional as F
4246

4347

48+
def count_nodes(graph_module, target):
49+
targets = target if isinstance(target, Iterable) else [target]
50+
51+
count = 0
52+
for node in graph_module.graph.nodes:
53+
if node.op == "call_function" and node.target in targets:
54+
count += 1
55+
return count
56+
4457
class TestLinear(torch.nn.Module):
4558
def __init__(self):
4659
super().__init__()
@@ -662,13 +675,6 @@ def _get_random_inputs(cls):
662675
partitioner=[NonDecompTestPartitioner()],
663676
)
664677

665-
def count_nodes(graph_module, target):
666-
count = 0
667-
for node in graph_module.graph.nodes:
668-
if node.op == "call_function" and node.target == target:
669-
count += 1
670-
return count
671-
672678
# There should be 1 call_delegate node and 1 node for aten.mm.default for the
673679
# linear that doesn't have a bias which was decomposed as the partitioner
674680
# said this node wasn't supported.
@@ -723,13 +729,6 @@ def _test_to_edge_with_preserved_ops(
723729
):
724730
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
725731

726-
def count_nodes(graph_module, target):
727-
count = 0
728-
for node in graph_module.graph.nodes:
729-
if node.op == "call_function" and node.target in target:
730-
count += 1
731-
return count
732-
733732
aten_ops_non_decomposed = count_nodes(
734733
program.graph_module,
735734
preserved_ops,
@@ -811,3 +810,31 @@ def test_save_fails(self):
811810
et = edge.to_executorch()
812811
with self.assertRaises(ValueError):
813812
_ = et.save("/tmp/test_save.pt")
813+
814+
def test_additional_decomposed_ops(self):
815+
"""
816+
Validate that EXECUTORCH_ADDITIONAL_DECOMPOSED_OPS are decomposed.
817+
"""
818+
class TestModel(torch.nn.Module):
819+
def forward(self, x):
820+
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
821+
y = torch.nn.functional.interpolate(y, scale_factor=2, mode="bilinear")
822+
return y
823+
824+
test_ops = [
825+
torch.ops.aten.upsample_bilinear2d.vec,
826+
torch.ops.aten.upsample_nearest2d.vec
827+
]
828+
inputs = (torch.randn(1, 1, 4, 4),)
829+
program = torch.export.export(TestModel(), inputs)
830+
831+
for op in test_ops:
832+
self.assertEqual(1, count_nodes(program.graph_module, op))
833+
834+
edge1 = to_edge(program)
835+
edge2 = to_edge_transform_and_lower(program)
836+
837+
for edge in [edge1, edge2]:
838+
for op in test_ops:
839+
edge_op = aten_to_edge(op)
840+
self.assertEqual(0, count_nodes(edge.exported_program().graph_module, edge_op))

exir/tracer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@
6464

6565
torchdynamo_enabled = False
6666

67+
"""
68+
Additional decompositions to apply by during to_edge or
69+
to to_edge_transform_and_lower in addition to the default decompositions from
70+
PyTorch export.
71+
"""
72+
EXECUTORCH_ADDITIONAL_DECOMPOSITIONS = [
73+
torch.ops.aten.upsample_bilinear2d.vec,
74+
torch.ops.aten.upsample_nearest2d.vec,
75+
]
76+
6777

6878
def get_stacktrace() -> List[Dict[str, str]]:
6979
"""
@@ -631,8 +641,12 @@ def _default_decomposition_table(
631641
]
632642
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
633643
return get_decompositions(decomp_opset)
644+
634645
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
635-
return default_decompositions()
646+
table = default_decompositions()
647+
additional_decompositions = get_decompositions(EXECUTORCH_ADDITIONAL_DECOMPOSITIONS)
648+
table.decomp_table.update(additional_decompositions)
649+
return table
636650

637651

638652
def dynamo_trace(

0 commit comments

Comments
 (0)