Skip to content

Commit d421658

Browse files
GregoryComer authored and facebook-github-bot committed
Exclude upsample_bilinear2d.vec and nearest2d.vec from default export decomposition table (pytorch#141791)
Summary: As upsample_bilinear2d.vec and upsample_nearest2d.vec are core ATen ops, they should not be decomposed by default in the export path. Because the operators have CompositeImplicitAutograd dispatch, their decomposition is registered by default. This change adds an override list for CIA decompositions being registered in the default decomp table. In the long-term, we likely will want to exclude decompositions for all core-tagged CIA ops, but this will require all consumers to be ready to handle the remaining two ops, avg_pool1d, and adaptive_avg_pool1d. Until they are ready, I believe an explicit override list is the safest option. Additionally, I've also removed the ExecuTorch XNNPACK delegate ConvertToUpsampleBilinear2d pass, as the pass breaks (and is not needed), given that the op is not decomposed. The purpose of this pass was originally to pattern match the decomposition and recompose it, but this is no longer necessary. X-link: pytorch/executorch#7126 Test Plan: Added a new test (`test_default_decomposition_core_cia_ops`) in test_export.py to verify that upsample_bilinear2d.vec (and in the future, other core-tagged CIA ops) are not decomposed by default. Also, I manually validated end to end with ExecuTorch that the op is not decomposed in to_edge (see N6238522). ``` buck test //caffe2/test:test_export -- test_default_decomposition_core_cia_ops ``` Reviewed By: digantdesai Differential Revision: D66575454
1 parent 9c78fb9 commit d421658

File tree

2 files changed

+62
-8
lines changed

2 files changed

+62
-8
lines changed

test/export/test_export.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11918,6 +11918,48 @@ def forward(self, x):
1191811918
]
1191911919
self.assertEqual(len(shift_op), 1)
1192011920

11921+
def test_default_decomposition_core_cia_ops(self):
11922+
"""
11923+
Verify that core ATen ops with Composite Implicit Autograd dispatch are not
11924+
decomposed by default.
11925+
"""
11926+
11927+
# TODO Add avg_pool1d, and adaptive_avg_pool1d when ready.
11928+
# See issue #116684.
11929+
core_cia_ops = {
11930+
"torch.ops.aten.upsample_bilinear2d.vec": (
11931+
torch.ops.aten.upsample_bilinear2d.vec,
11932+
{
11933+
"align_corners": False,
11934+
"scale_factors": [2, 2],
11935+
"output_size": None,
11936+
},
11937+
),
11938+
"torch.ops.aten.upsample_nearest2d.vec": (
11939+
torch.ops.aten.upsample_nearest2d.vec,
11940+
{
11941+
"scale_factors": [2, 2],
11942+
"output_size": None,
11943+
},
11944+
),
11945+
}
11946+
11947+
for op_name, (op, kwargs) in core_cia_ops.items():
11948+
11949+
class M(torch.nn.Module):
11950+
def forward(self, x):
11951+
return op(x, **kwargs)
11952+
11953+
ep = export(M(), (torch.randn(2, 3, 4, 5),))
11954+
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
11955+
11956+
decomp_table = default_decompositions()
11957+
11958+
ep = ep.run_decompositions(
11959+
decomp_table=decomp_table,
11960+
)
11961+
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
11962+
1192111963

1192211964
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
1192311965
class TestOneOffModelExportResult(TestCase):
@@ -12523,30 +12565,30 @@ def forward(self, x):
1252312565
torch.distributed.destroy_process_group()
1252412566

1252512567
def test_preserve_cia_op(self):
12526-
class StaticResizeBilinear2dModule(torch.nn.Module):
12568+
class StaticResizeTrilinear2dModule(torch.nn.Module):
1252712569
def forward(self, x):
1252812570
a = torch.nn.functional.interpolate(
1252912571
x,
12530-
size=(x.shape[2] * 2, x.shape[3] * 3),
12531-
mode="bilinear",
12572+
size=(x.shape[2] * 2, x.shape[3] * 3, x.shape[4] * 4),
12573+
mode="trilinear",
1253212574
align_corners=False,
1253312575
antialias=False,
1253412576
)
1253512577
return a
1253612578

12537-
ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),))
12579+
ep = export(StaticResizeTrilinear2dModule(), (torch.randn(2, 3, 4, 5, 6),))
1253812580
FileCheck().check_count(
12539-
"torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
12581+
"torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
1254012582
).run(ep.graph_module.code)
1254112583

1254212584
decomp_table = default_decompositions()
12543-
del decomp_table[torch.ops.aten.upsample_bilinear2d.vec]
12585+
del decomp_table[torch.ops.aten.upsample_trilinear3d.vec]
1254412586
ep = ep.run_decompositions(
1254512587
decomp_table=decomp_table,
1254612588
)
1254712589

1254812590
FileCheck().check_count(
12549-
"torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
12591+
"torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
1255012592
).run(ep.graph_module.code)
1255112593

1255212594

torch/export/decomp_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
__all__ = ["CustomDecompTable"]
1414

1515

16+
"""
17+
Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition
18+
by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all
19+
backends are ready, this list allows opt-in one at a time.
20+
"""
21+
PRESERVED_ATEN_CIA_OPS = {
22+
torch.ops.aten.upsample_bilinear2d.vec,
23+
torch.ops.aten.upsample_nearest2d.vec,
24+
}
25+
26+
1627
class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
1728
"""
1829
This is a custom dictionary that is specifically used for handling decomp_table in export.
@@ -38,7 +49,8 @@ def __init__(self):
3849
self.decomp_table = _core_aten_decompositions_post_autograd()
3950

4051
for op in _collect_all_valid_cia_ops_for_aten_namespace():
41-
self.decomp_table[op] = _get_decomp_for_cia(op)
52+
if op not in PRESERVED_ATEN_CIA_OPS:
53+
self.decomp_table[op] = _get_decomp_for_cia(op)
4254

4355
# This is to track the *pending* deleted custom ops that haven't been materialized yet
4456
self.deleted_custom_ops = set()

0 commit comments

Comments (0)