Skip to content

Commit 6b0713b

Browse files
chunnienc authored and copybara-github committed
torch 2.6.0 upgrade patch
All patches to make ai-edge-torch work with torch 2.6.0: 1. `capture_pre_autograd_graph` is removed in torch 2.6.0. Replace this with `export_for_training` in PT2E testing. Tutorial updates may come in later CLs. 2. `_decomp_table_to_post_autograd_aten` (added in 2.5.0) is removed in 2.6.0. For now there is no built-in decomp table to transform FX into the state we expect as in 2.4.0. Therefore a new core-aten based decomp registry system is added in this CL to make decomp behavior expected in 2.4.0, 2.5.0, and 2.6.0. Tests are passing in 2.5.0 and 2.6.0, but the impact on converted model performance needs more investigation. 3. 2.6.0 introduces the `_assert_tensor_metadata` op to guard custom ops in the exported graph. This will break pattern matching and lowering. Add an `fx_infra` with `safe_run_decompositions` to exclude this op and all the other unexpected export behavior. PiperOrigin-RevId: 718173217
1 parent 34296a8 commit 6b0713b

35 files changed

+461
-237
lines changed

ai_edge_torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
1615
from ai_edge_torch._config import config
1716
from ai_edge_torch._convert.converter import convert
1817
from ai_edge_torch._convert.converter import signature
1918
from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
2019
from ai_edge_torch.model import Model
2120
from ai_edge_torch.version import __version__
2221

22+
2323
def load(path: str) -> Model:
2424
"""Imports an ai_edge_torch model from disk.
2525

ai_edge_torch/_convert/conversion.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Literal, Optional, Union
1818

1919
import ai_edge_torch
20-
from ai_edge_torch import fx_pass_base
20+
from ai_edge_torch import fx_infra
2121
from ai_edge_torch import lowertools
2222
from ai_edge_torch import model
2323
from ai_edge_torch._convert import fx_passes
@@ -53,7 +53,7 @@ def _run_convert_passes(
5353
fx_passes.CanonicalizePass(),
5454
]
5555

56-
exported_program = fx_pass_base.run_passes(exported_program, passes)
56+
exported_program = fx_infra.run_passes(exported_program, passes)
5757
return exported_program
5858

5959

@@ -125,14 +125,10 @@ def export(**kwargs):
125125
else:
126126
exported_program = torch.export.export(**kwargs, strict=True)
127127

128-
if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
129-
# Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
130-
# stop-gap table which replicates the old behaviour of post-dispatch IR.
131-
# This could help ensure the collection of aten ops remaining still as the
132-
# implementation of torch.export changes.
133-
exported_program = exported_program.run_decompositions(
134-
torch._decomp._decomp_table_to_post_autograd_aten()
135-
)
128+
exported_program = fx_infra.safe_run_decompositions(
129+
exported_program,
130+
fx_infra.decomp.pre_convert_decomp(),
131+
)
136132
return exported_program
137133

138134
exported_programs: torch.export.ExportedProgram = [

ai_edge_torch/_convert/fx_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
2121
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
2222
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
23-
from ai_edge_torch.fx_pass_base import CanonicalizePass
23+
from ai_edge_torch.fx_infra import CanonicalizePass

ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515

1616
from typing import Any, Callable
17-
from ai_edge_torch import fx_pass_base
17+
from ai_edge_torch import fx_infra
1818
from ai_edge_torch import lowertools
1919
import torch
2020
import torch.utils._pytree as pytree
@@ -25,6 +25,9 @@
2525

2626

2727
def _register_composite_builder(op):
28+
# Remove op from pre_convert_decomp to keep this in the decomposed graph.
29+
fx_infra.decomp.remove_pre_convert_decomp(op)
30+
2831
def inner(func):
2932
if isinstance(op, torch._ops.OpOverloadPacket):
3033
for overload in op.overloads():
@@ -276,7 +279,7 @@ def embedding(*args, **kwargs):
276279
node.target = embedding
277280

278281

279-
class BuildAtenCompositePass(fx_pass_base.PassBase):
282+
class BuildAtenCompositePass(fx_infra.PassBase):
280283

281284
def call(self, graph_module: torch.fx.GraphModule):
282285
for node in graph_module.graph.nodes:
@@ -285,4 +288,4 @@ def call(self, graph_module: torch.fx.GraphModule):
285288

286289
graph_module.graph.lint()
287290
graph_module.recompile()
288-
return fx_pass_base.PassResult(graph_module, True)
291+
return fx_infra.PassResult(graph_module, True)

ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import functools
1818

19-
from ai_edge_torch import fx_pass_base
19+
from ai_edge_torch import fx_infra
2020
from ai_edge_torch.hlfb import mark_pattern
2121
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
2222
import torch
@@ -41,7 +41,7 @@ def _get_upsample_bilinear2d_pattern():
4141
x, scale_factor=2, mode="bilinear", align_corners=False
4242
),
4343
export_args=(torch.rand(1, 3, 100, 100),),
44-
decomp_table=_INTERPOLATE_DECOMPOSITIONS,
44+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
4545
)
4646

4747
@pattern.register_attr_builder
@@ -65,7 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
6565
x, scale_factor=2, mode="bilinear", align_corners=True
6666
),
6767
export_args=(torch.rand(1, 3, 100, 100),),
68-
decomp_table=_INTERPOLATE_DECOMPOSITIONS,
68+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
6969
)
7070

7171
@pattern.register_attr_builder
@@ -89,7 +89,7 @@ def _get_interpolate_nearest2d_pattern():
8989
x, scale_factor=2, mode="nearest"
9090
),
9191
export_args=(torch.rand(1, 3, 100, 100),),
92-
decomp_table=_INTERPOLATE_DECOMPOSITIONS,
92+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
9393
)
9494

9595
@pattern.register_attr_builder
@@ -104,7 +104,7 @@ def attr_builder(pattern, graph_module, internal_match):
104104
return pattern
105105

106106

107-
class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
107+
class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
108108

109109
def __init__(self):
110110
super().__init__()
@@ -115,11 +115,9 @@ def __init__(self):
115115
]
116116

117117
def call(self, exported_program: torch.export.ExportedProgram):
118-
exported_program = fx_pass_base.run_passes(
119-
exported_program, [fx_pass_base.CanonicalizePass()]
120-
)
121-
exported_program = exported_program.run_decompositions(
122-
_INTERPOLATE_DECOMPOSITIONS
118+
exported_program = fx_infra.safe_run_decompositions(
119+
exported_program,
120+
_INTERPOLATE_DECOMPOSITIONS,
123121
)
124122

125123
graph_module = exported_program.graph_module
@@ -128,4 +126,4 @@ def call(self, exported_program: torch.export.ExportedProgram):
128126

129127
graph_module.graph.lint()
130128
graph_module.recompile()
131-
return fx_pass_base.ExportedProgramPassResult(exported_program, True)
129+
return fx_infra.ExportedProgramPassResult(exported_program, True)

ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
from ai_edge_torch import fx_pass_base
16+
from ai_edge_torch import fx_infra
1717
from ai_edge_torch import lowertools
1818
import torch
1919
import torch.utils._pytree as pytree
@@ -61,7 +61,7 @@ def debuginfo_writer(*args, **kwargs):
6161
node.target = debuginfo_writer
6262

6363

64-
class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
64+
class InjectMlirDebuginfoPass(fx_infra.PassBase):
6565
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""
6666

6767
def call(self, graph_module: torch.fx.GraphModule):
@@ -70,4 +70,4 @@ def call(self, graph_module: torch.fx.GraphModule):
7070

7171
graph_module.graph.lint()
7272
graph_module.recompile()
73-
return fx_pass_base.PassResult(graph_module, True)
73+
return fx_infra.PassResult(graph_module, True)

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import _decomp_registry
1617
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Remove decompositions for ops to keep in layout optimization."""
16+
from ai_edge_torch import fx_infra
17+
import torch
18+
19+
__all__ = []
20+
21+
aten = torch.ops.aten
22+
23+
_OPS_TO_KEEP = [
24+
aten.conv2d,
25+
aten.max_pool2d,
26+
aten._softmax.default,
27+
aten.group_norm.default,
28+
aten.native_group_norm.default,
29+
aten.reflection_pad2d.default,
30+
]
31+
32+
for op in _OPS_TO_KEEP:
33+
fx_infra.decomp.remove_pre_convert_decomp(op)

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
class OpFuncRegistry(dict):
2121

2222
def register(self, op):
23+
2324
ops = utils.flatten_torch_op_overloads(op)
2425

2526
def inner(func):

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
from typing import Union
2020

21-
from ai_edge_torch import fx_pass_base
21+
from ai_edge_torch import fx_infra
2222
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
2323
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
2424
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
@@ -30,7 +30,7 @@
3030
TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
3131

3232

33-
class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
33+
class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
3434

3535
def get_source_meta(self, node: torch.fx.Node):
3636
keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -300,4 +300,4 @@ def call(self, exported_program: torch.export.ExportedProgram):
300300
# Mark const node again for debugging
301301
self.mark_const_nodes(exported_program)
302302

303-
return fx_pass_base.ExportedProgramPassResult(exported_program, True)
303+
return fx_infra.ExportedProgramPassResult(exported_program, True)

0 commit comments

Comments
 (0)