Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import copy
import unittest
from collections.abc import Iterable
from typing import Any, Dict

import torch
Expand All @@ -21,6 +22,9 @@
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.replace_aten_with_edge_pass import (
aten_to_edge,
)
from executorch.exir.program._program import (
EdgeProgramManager,
ExecutorchProgramManager,
Expand All @@ -41,6 +45,15 @@
from torch.nn import functional as F


def count_nodes(graph_module, target):
targets = target if isinstance(target, Iterable) else [target]

count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in targets:
count += 1
return count

class TestLinear(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -662,13 +675,6 @@ def _get_random_inputs(cls):
partitioner=[NonDecompTestPartitioner()],
)

def count_nodes(graph_module, target):
count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == target:
count += 1
return count

# There should be 1 call_delegate node and 1 node for aten.mm.default for the
# linear that doesn't have a bias which was decomposed as the partitioner
# said this node wasn't supported.
Expand Down Expand Up @@ -723,13 +729,6 @@ def _test_to_edge_with_preserved_ops(
):
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)

def count_nodes(graph_module, target):
count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in target:
count += 1
return count

aten_ops_non_decomposed = count_nodes(
program.graph_module,
preserved_ops,
Expand Down Expand Up @@ -811,3 +810,31 @@ def test_save_fails(self):
et = edge.to_executorch()
with self.assertRaises(ValueError):
_ = et.save("/tmp/test_save.pt")

def test_additional_decomposed_ops(self):
"""
Validate that EXECUTORCH_ADDITIONAL_DECOMPOSED_OPS are decomposed.
"""
class TestModel(torch.nn.Module):
def forward(self, x):
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
y = torch.nn.functional.interpolate(y, scale_factor=2, mode="bilinear")
return y

test_ops = [
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec
]
inputs = (torch.randn(1, 1, 4, 4),)
program = torch.export.export(TestModel(), inputs)

for op in test_ops:
self.assertEqual(1, count_nodes(program.graph_module, op))

edge1 = to_edge(program)
edge2 = to_edge_transform_and_lower(program)

for edge in [edge1, edge2]:
for op in test_ops:
edge_op = aten_to_edge(op)
self.assertEqual(0, count_nodes(edge.exported_program().graph_module, edge_op))
16 changes: 15 additions & 1 deletion exir/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@

torchdynamo_enabled = False

"""
Additional decompositions to apply by during to_edge or
to to_edge_transform_and_lower in addition to the default decompositions from
PyTorch export.
"""
EXECUTORCH_ADDITIONAL_DECOMPOSITIONS = [
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
]


def get_stacktrace() -> List[Dict[str, str]]:
"""
Expand Down Expand Up @@ -631,8 +641,12 @@ def _default_decomposition_table(
]
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
return get_decompositions(decomp_opset)

# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
return default_decompositions()
table = default_decompositions()
additional_decompositions = get_decompositions(EXECUTORCH_ADDITIONAL_DECOMPOSITIONS)
table.decomp_table.update(additional_decompositions)
return table


def dynamo_trace(
Expand Down
Loading