
Commit 6eae66d

GregoryComer authored and facebook-github-bot committed
Exclude upsample_bilinear2d.vec and nearest2d.vec from default export decomposition table (pytorch#141791)
Summary:
As upsample_bilinear2d.vec and upsample_nearest2d.vec are core ATen ops, they should not be decomposed by default in the export path. Because these operators have CompositeImplicitAutograd dispatch, their decompositions are registered by default. This change adds an override list for CIA decompositions that would otherwise be registered in the default decomp table.

In the long term, we will likely want to exclude decompositions for all core-tagged CIA ops, but that requires all consumers to be ready to handle the two remaining ops, avg_pool1d and adaptive_avg_pool1d. Until they are, an explicit override list is the safest option.

Additionally, this change removes the ExecuTorch XNNPACK delegate ConvertToUpsampleBilinear2d pass, which breaks (and is no longer needed) now that the op is not decomposed. The pass originally existed to pattern match the decomposition and recompose it; that is no longer necessary.

X-link: pytorch/executorch#7126

Test Plan:
Added a new test (`test_default_decomposition_core_cia_ops`) in test_export.py to verify that upsample_bilinear2d.vec (and, in the future, other core-tagged CIA ops) are not decomposed by default. Also manually validated end to end with ExecuTorch that the op is not decomposed in to_edge (see N6238522).

```
buck test //caffe2/test:test_export -- test_default_decomposition_core_cia_ops
```

Differential Revision: D66575454
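For context, here is a minimal sketch (not part of this commit's diff) of the behavior the change enables, using the public export APIs; the module and input shape are illustrative:

```python
import torch
from torch.export import export, default_decompositions


class Upsample(torch.nn.Module):
    def forward(self, x):
        # Lowers to torch.ops.aten.upsample_bilinear2d.vec during export.
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        )


ep = export(Upsample(), (torch.randn(2, 3, 4, 5),))
ep = ep.run_decompositions(decomp_table=default_decompositions())

# After this change, upsample_bilinear2d.vec should survive in the graph
# instead of being replaced by its CompositeImplicitAutograd expansion.
print(ep.graph_module.code)
```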
1 parent 64e54d5 commit 6eae66d

File tree

2 files changed: +62 −8 lines changed


test/export/test_export.py

Lines changed: 49 additions & 7 deletions
```diff
@@ -11613,6 +11613,48 @@ def forward(self, x):
         ref_res = module(*dyn_inp)
         self.assertEqual(export_res, ref_res)
 
+    def test_default_decomposition_core_cia_ops(self):
+        """
+        Verify that core ATen ops with Composite Implicit Autograd dispatch are not
+        decomposed by default.
+        """
+
+        # TODO Add avg_pool1d, and adaptive_avg_pool1d when ready.
+        # See issue #116684.
+        core_cia_ops = {
+            "torch.ops.aten.upsample_bilinear2d.vec": (
+                torch.ops.aten.upsample_bilinear2d.vec,
+                {
+                    "align_corners": False,
+                    "scale_factors": [2, 2],
+                    "output_size": None,
+                },
+            ),
+            "torch.ops.aten.upsample_nearest2d.vec": (
+                torch.ops.aten.upsample_nearest2d.vec,
+                {
+                    "scale_factors": [2, 2],
+                    "output_size": None,
+                },
+            ),
+        }
+
+        for op_name, (op, kwargs) in core_cia_ops.items():
+
+            class M(torch.nn.Module):
+                def forward(self, x):
+                    return op(x, **kwargs)
+
+            ep = export(M(), (torch.randn(2, 3, 4, 5),))
+            FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
+
+            decomp_table = default_decompositions()
+
+            ep = ep.run_decompositions(
+                decomp_table=decomp_table,
+            )
+            FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
@@ -12185,30 +12227,30 @@ def forward(self, x):
         )
 
     def test_preserve_cia_op(self):
-        class StaticResizeBilinear2dModule(torch.nn.Module):
+        class StaticResizeTrilinear2dModule(torch.nn.Module):
             def forward(self, x):
                 a = torch.nn.functional.interpolate(
                     x,
-                    size=(x.shape[2] * 2, x.shape[3] * 3),
-                    mode="bilinear",
+                    size=(x.shape[2] * 2, x.shape[3] * 3, x.shape[4] * 4),
+                    mode="trilinear",
                     align_corners=False,
                     antialias=False,
                 )
                 return a
 
-        ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),))
+        ep = export(StaticResizeTrilinear2dModule(), (torch.randn(2, 3, 4, 5, 6),))
         FileCheck().check_count(
-            "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
+            "torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
         ).run(ep.graph_module.code)
 
         decomp_table = default_decompositions()
-        del decomp_table[torch.ops.aten.upsample_bilinear2d.vec]
+        del decomp_table[torch.ops.aten.upsample_trilinear3d.vec]
        ep = ep.run_decompositions(
             decomp_table=decomp_table,
         )
 
         FileCheck().check_count(
-            "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
+            "torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True
         ).run(ep.graph_module.code)
```
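The updated `test_preserve_cia_op` exercises the existing opt-in preservation path for CIA ops that are not on the new override list. A standalone sketch of that pattern (module name and shapes are illustrative, not taken from the diff):

```python
import torch
from torch.export import export, default_decompositions


class StaticResizeTrilinear3d(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="trilinear", align_corners=False
        )


ep = export(StaticResizeTrilinear3d(), (torch.randn(2, 3, 4, 5, 6),))

# Deleting an entry from the default table preserves that CIA op through
# run_decompositions(), exactly as test_preserve_cia_op checks above.
decomp_table = default_decompositions()
del decomp_table[torch.ops.aten.upsample_trilinear3d.vec]
ep = ep.run_decompositions(decomp_table=decomp_table)
```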

torch/export/decomp_utils.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -13,6 +13,17 @@
 __all__ = ["CustomDecompTable"]
 
 
+"""
+Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition
+by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all
+backends are ready, this list allows opt-in one at a time.
+"""
+PRESERVED_ATEN_CIA_OPS = {
+    torch.ops.aten.upsample_bilinear2d.vec,
+    torch.ops.aten.upsample_nearest2d.vec,
+}
+
+
 class CustomDecompTable(Dict[torch._ops.OperatorBase, Callable]):
     """
     This is a custom dictionary that is specifically used for handling decomp_table in export.
@@ -38,7 +49,8 @@ def __init__(self):
         self.decomp_table = _core_aten_decompositions_post_autograd()
 
         for op in _collect_all_valid_cia_ops_for_aten_namespace():
-            self.decomp_table[op] = _get_decomp_for_cia(op)
+            if op not in PRESERVED_ATEN_CIA_OPS:
+                self.decomp_table[op] = _get_decomp_for_cia(op)
 
         # This is to track the *pending* deleted custom ops that haven't been materialized yet
         self.deleted_custom_ops = set()
```
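For a consumer that is not yet ready to handle a preserved op, one plausible escape hatch is to re-register its CIA decomposition on the table before running decompositions. This is a hedged sketch: the import path of `_get_decomp_for_cia` is an assumption (the diff only shows its call site), and it is a private helper that may move.

```python
import torch
from torch.export import default_decompositions

# Import path is an assumption; _get_decomp_for_cia is the same private
# helper used in the decomp_utils.py diff above.
from torch._export.utils import _get_decomp_for_cia

decomp_table = default_decompositions()
op = torch.ops.aten.upsample_bilinear2d.vec

# Opt back in: restore the decomposition for a backend that still expects
# this op to be expanded.
decomp_table[op] = _get_decomp_for_cia(op)
```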
