Skip to content

Commit e61c9b5

Browse files
committed
Add additional default decompositions for upsample operators
Summary: There are several core ATen ops that are not yet supported on ExecuTorch, including upsample_bilinear2d.vec and upsample_nearest2d.vec. These ops are currently not decomposed by default with PyTorch export default decompositions, but should be. Existing ET consumers rely on this behavior, so we need to preserve it until we have upsample kernels ready. This change allows ET to opt-into decomposing these ops, regardless of the PyTorch default export decomposition table. This will unblock updating PyTorch with the correct behavior (see pytorch/pytorch#116684). Once the upsample kernels land in ET, we can remove these decompositions. This is currently blocked by pin bumps, which may take a while to resolve. Differential Revision: D67443180
1 parent 2ed5ce3 commit e61c9b5

File tree

2 files changed

+58
-15
lines changed

2 files changed

+58
-15
lines changed

exir/program/test/test_program.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import copy
1010
import unittest
11+
from collections.abc import Iterable
1112
from typing import Any, Dict
1213

1314
import torch
@@ -21,6 +22,7 @@
2122
from executorch.exir.lowered_backend_module import get_lowered_submodules
2223
from executorch.exir.pass_base import ExportPass
2324
from executorch.exir.passes import MemoryPlanningPass
25+
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
2426
from executorch.exir.program._program import (
2527
EdgeProgramManager,
2628
ExecutorchProgramManager,
@@ -41,6 +43,16 @@
4143
from torch.nn import functional as F
4244

4345

46+
def count_nodes(graph_module, target):
47+
targets = target if isinstance(target, Iterable) else [target]
48+
49+
count = 0
50+
for node in graph_module.graph.nodes:
51+
if node.op == "call_function" and node.target in targets:
52+
count += 1
53+
return count
54+
55+
4456
class TestLinear(torch.nn.Module):
4557
def __init__(self):
4658
super().__init__()
@@ -662,13 +674,6 @@ def _get_random_inputs(cls):
662674
partitioner=[NonDecompTestPartitioner()],
663675
)
664676

665-
def count_nodes(graph_module, target):
666-
count = 0
667-
for node in graph_module.graph.nodes:
668-
if node.op == "call_function" and node.target == target:
669-
count += 1
670-
return count
671-
672677
# There should be 1 call_delegate node and 1 node for aten.mm.default for the
673678
# linear that doesn't have a bias which was decomposed as the partitioner
674679
# said this node wasn't supported.
@@ -723,13 +728,6 @@ def _test_to_edge_with_preserved_ops(
723728
):
724729
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
725730

726-
def count_nodes(graph_module, target):
727-
count = 0
728-
for node in graph_module.graph.nodes:
729-
if node.op == "call_function" and node.target in target:
730-
count += 1
731-
return count
732-
733731
aten_ops_non_decomposed = count_nodes(
734732
program.graph_module,
735733
preserved_ops,
@@ -811,3 +809,34 @@ def test_save_fails(self):
811809
et = edge.to_executorch()
812810
with self.assertRaises(ValueError):
813811
_ = et.save("/tmp/test_save.pt")
812+
813+
def test_additional_decomposed_ops(self):
814+
"""
815+
Validate that EXECUTORCH_ADDITIONAL_DECOMPOSED_OPS are decomposed.
816+
"""
817+
818+
class TestModel(torch.nn.Module):
819+
def forward(self, x):
820+
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
821+
y = torch.nn.functional.interpolate(y, scale_factor=2, mode="bilinear")
822+
return y
823+
824+
test_ops = [
825+
torch.ops.aten.upsample_bilinear2d.vec,
826+
torch.ops.aten.upsample_nearest2d.vec,
827+
]
828+
inputs = (torch.randn(1, 1, 4, 4),)
829+
program = torch.export.export(TestModel(), inputs)
830+
831+
for op in test_ops:
832+
self.assertEqual(1, count_nodes(program.graph_module, op))
833+
834+
edge1 = to_edge(program)
835+
edge2 = to_edge_transform_and_lower(program)
836+
837+
for edge in [edge1, edge2]:
838+
for op in test_ops:
839+
edge_op = aten_to_edge(op)
840+
self.assertEqual(
841+
0, count_nodes(edge.exported_program().graph_module, edge_op)
842+
)

exir/tracer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@
6464

6565
torchdynamo_enabled = False
6666

67+
"""
68+
Additional decompositions to apply by during to_edge or
69+
to to_edge_transform_and_lower in addition to the default decompositions from
70+
PyTorch export.
71+
"""
72+
EXECUTORCH_ADDITIONAL_DECOMPOSITIONS = [
73+
torch.ops.aten.upsample_bilinear2d.vec,
74+
torch.ops.aten.upsample_nearest2d.vec,
75+
]
76+
6777

6878
def get_stacktrace() -> List[Dict[str, str]]:
6979
"""
@@ -631,8 +641,12 @@ def _default_decomposition_table(
631641
]
632642
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
633643
return get_decompositions(decomp_opset)
644+
634645
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
635-
return default_decompositions()
646+
table = default_decompositions()
647+
additional_decompositions = get_decompositions(EXECUTORCH_ADDITIONAL_DECOMPOSITIONS)
648+
table.decomp_table.update(additional_decompositions)
649+
return table
636650

637651

638652
def dynamo_trace(

0 commit comments

Comments
 (0)