diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index d07972f971a..298c40f436f 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -8,6 +8,7 @@ import copy import unittest +from collections.abc import Iterable from typing import Any, Dict import torch @@ -21,6 +22,9 @@ from executorch.exir.lowered_backend_module import get_lowered_submodules from executorch.exir.pass_base import ExportPass from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.replace_aten_with_edge_pass import ( + aten_to_edge, +) from executorch.exir.program._program import ( EdgeProgramManager, ExecutorchProgramManager, @@ -41,6 +45,15 @@ from torch.nn import functional as F +def count_nodes(graph_module, target): + targets = target if isinstance(target, Iterable) else [target] + + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in targets: + count += 1 + return count + class TestLinear(torch.nn.Module): def __init__(self): super().__init__() @@ -662,13 +675,6 @@ def _get_random_inputs(cls): partitioner=[NonDecompTestPartitioner()], ) - def count_nodes(graph_module, target): - count = 0 - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target == target: - count += 1 - return count - # There should be 1 call_delegate node and 1 node for aten.mm.default for the # linear that doesn't have a bias which was decomposed as the partitioner # said this node wasn't supported. @@ -723,13 +729,6 @@ def _test_to_edge_with_preserved_ops( ): edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops) - def count_nodes(graph_module, target): - count = 0 - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target in target: - count += 1 - return count - aten_ops_non_decomposed = count_nodes( program.graph_module, preserved_ops, @@ -811,3 +810,31 @@ def test_save_fails(self): et = edge.to_executorch() with self.assertRaises(ValueError): _ = et.save("/tmp/test_save.pt") + + def test_additional_decomposed_ops(self): + """ + Validate that EXECUTORCH_ADDITIONAL_DECOMPOSED_OPS are decomposed. + """ + class TestModel(torch.nn.Module): + def forward(self, x): + y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") + y = torch.nn.functional.interpolate(y, scale_factor=2, mode="bilinear") + return y + + test_ops = [ + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec + ] + inputs = (torch.randn(1, 1, 4, 4),) + program = torch.export.export(TestModel(), inputs) + + for op in test_ops: + self.assertEqual(1, count_nodes(program.graph_module, op)) + + edge1 = to_edge(program) + edge2 = to_edge_transform_and_lower(program) + + for edge in [edge1, edge2]: + for op in test_ops: + edge_op = aten_to_edge(op) + self.assertEqual(0, count_nodes(edge.exported_program().graph_module, edge_op)) diff --git a/exir/tracer.py b/exir/tracer.py index 82f93424a14..58f25080a3b 100644 --- a/exir/tracer.py +++ b/exir/tracer.py @@ -64,6 +64,16 @@ torchdynamo_enabled = False +""" +Additional decompositions to apply by during to_edge or +to to_edge_transform_and_lower in addition to the default decompositions from +PyTorch export. +""" +EXECUTORCH_ADDITIONAL_DECOMPOSITIONS = [ + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, +] + def get_stacktrace() -> List[Dict[str, str]]: """ @@ -631,8 +641,12 @@ def _default_decomposition_table( ] # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e... return get_decompositions(decomp_opset) + # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir.... - return default_decompositions() + table = default_decompositions() + additional_decompositions = get_decompositions(EXECUTORCH_ADDITIONAL_DECOMPOSITIONS) + table.decomp_table.update(additional_decompositions) + return table def dynamo_trace(