Skip to content

Commit 56b4375

Browse files
chunniencc authored and copybara-github committed
Fix ai_edge_torch for g3 torch upgrade
PiperOrigin-RevId: 845961480
1 parent 654788f commit 56b4375

File tree

3 files changed

+113
-2
lines changed

3 files changed

+113
-2
lines changed

ai_edge_torch/fx_infra/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@
3030

3131
CanonicalizePass = _canonicalize_pass.CanonicalizePass
3232
safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
33+
annotate_force_decomp = _safe_run_decompositions.annotate_force_decomp

ai_edge_torch/fx_infra/_safe_run_decompositions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
1616
import operator
17+
from typing import Any, Callable
1718
import torch
1819

1920

@@ -59,6 +60,15 @@ def _require_decomp(
5960
return False
6061

6162

63+
# Marker attribute stamped onto decomposition callables that must always be
# applied by safe_run_decompositions, even when torch would skip them.
_FORCE_DECOMP_ATTR = "_ai_edge_torch_force_decomp"


def annotate_force_decomp(decomp: Callable[..., Any]):
  """Marks a decomposition so safe_run_decompositions always runs it shallowly.

  Args:
    decomp: The decomposition callable to annotate.

  Returns:
    The same callable, with the force-decomp marker attribute set.
  """
  # Only the attribute's presence matters (checked via hasattr); storing the
  # attribute name as its value keeps the marker self-describing.
  setattr(decomp, _FORCE_DECOMP_ATTR, _FORCE_DECOMP_ATTR)
  return decomp
70+
71+
6272
def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
6373
"""Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
6474

@@ -79,6 +89,14 @@ def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
7989
# back to one aten.view.
8090
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
8191

92+
# Torch may skip some decompositions even if target is in decomp_table.
93+
# The following ensures the target is always run through the decompositions
94+
# shallowly if it has _FORCE_DECOMP_ATTR.
95+
if decomp_table and node.target in decomp_table:
96+
decomp = decomp_table[node.target]
97+
if hasattr(decomp, _FORCE_DECOMP_ATTR):
98+
node.target = decomp
99+
82100
exported_program = exported_program.run_decompositions(decomp_table)
83101

84102
if hasattr(torch.ops.aten, "_assert_tensor_metadata"):

ai_edge_torch/odml_torch/lowerings/_decomp_registry.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,72 @@
1414
# ==============================================================================
1515
"""Torch export decompositions to run before lowering."""
1616

17+
import functools
1718
from ai_edge_torch import fx_infra
1819
import torch
1920

2021

22+
# Fork from pytorch/torch/_decomp/decompositions.py
23+
def upsample_compute_output_size(input_size, output_size, scale_factors):
24+
spatial_dimensions = len(input_size) - 2
25+
if output_size is not None:
26+
torch._check(
27+
scale_factors is None,
28+
lambda: "Must specify exactly one of output_size and scale_factors",
29+
)
30+
torch._check(len(output_size) == spatial_dimensions, lambda: "")
31+
return output_size
32+
if scale_factors is not None:
33+
# NB: this isn't necessary lol
34+
torch._check(
35+
output_size is None,
36+
lambda: "Must specify exactly one of output_size and scale_factors",
37+
)
38+
torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
39+
output_size = []
40+
for i, s in enumerate(scale_factors):
41+
if int(s) == s:
42+
output_size.append(input_size[i + 2] * int(s))
43+
else:
44+
output_size.append(torch.sym_int(input_size[i + 2] * s))
45+
return output_size
46+
torch._check(
47+
False, lambda: "Must specify exactly one of output_size and scale_factors"
48+
)
49+
50+
51+
# Fork from pytorch/torch/_decomp/decompositions.py
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
  """Builds per-dimension gather indices for nearest-neighbor upsampling.

  Args:
    input: Source tensor whose trailing dims are the spatial dims.
    output_size: Target size per spatial dimension.
    scales: Optional per-dimension scale factors (entries may be None).
    exact: When True, sample with a half-pixel center offset (the
      ``*-exact`` aten variants).

  Returns:
    A tuple of int64 index tensors, one per spatial dimension, shaped to
    broadcast against each other inside ``aten.index``.
  """
  num_spatial_dims = len(output_size)
  center_offset = 0.5 if exact else 0.0
  gathered = []
  for dim in range(num_spatial_dims):
    osize = output_size[dim]
    isize = input.shape[-num_spatial_dims + dim]
    # Step between consecutive source samples: 1/scale when a scale was
    # supplied, otherwise derived from the size ratio.
    if scales[dim] is not None:
      step = isize / (isize * scales[dim])
    else:
      step = isize / osize
    out_positions = torch.arange(
        osize, dtype=torch.float32, device=input.device
    )
    src_indices = ((out_positions + center_offset) * step).to(torch.int64)
    # Trailing singleton dims make the per-dimension index tensors broadcast
    # into a full spatial grid when passed to aten.index together.
    for _ in range(num_spatial_dims - 1 - dim):
      src_indices = src_indices.unsqueeze(-1)
    gathered.append(src_indices)
  return tuple(gathered)
72+
73+
74+
# Fork from pytorch/torch/_decomp/decompositions.py
def _upsample_nearest2d_common(input, h_indices, w_indices):
  """Gathers the nearest-neighbor pixels selected by the index tensors.

  Args:
    input: NCHW source tensor.
    h_indices: Row indices, shaped to broadcast against ``w_indices``.
    w_indices: Column indices.

  Returns:
    The upsampled tensor, forced to contiguous layout.
  """
  gathered = torch.ops.aten.index(input, (None, None, h_indices, w_indices))
  # aten.index can produce a non-contiguous result; normalize the layout.
  return gathered.contiguous()
79+
80+
2181
fx_infra.decomp.update_pre_lower_decomp(
2282
torch._decomp.get_decompositions([
23-
torch.ops.aten.upsample_nearest2d,
2483
torch.ops.aten._native_batch_norm_legit.no_stats,
2584
torch.ops.aten._native_batch_norm_legit_functional,
2685
torch.ops.aten._adaptive_avg_pool2d,
@@ -35,11 +94,44 @@
3594
torch.ops.aten.replication_pad2d,
3695
torch.ops.aten.replication_pad3d,
3796
torch.ops.aten.upsample_bilinear2d.vec,
38-
torch.ops.aten.upsample_nearest2d.vec,
3997
torch.ops.aten.addmm,
4098
])
4199
)
42100

101+
102+
@functools.partial(
    fx_infra.decomp.add_pre_lower_decomp,
    torch.ops.aten.upsample_nearest2d.default,
)
@fx_infra.annotate_force_decomp
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
  """Pre-lower decomposition for aten.upsample_nearest2d.default.

  Expands the op into an aten.index gather. Registered as a forced
  decomposition so safe_run_decompositions always applies it, even when
  torch's own decomposition pass would skip the target.
  """
  row_idx, col_idx = _compute_upsample_nearest_indices(
      input, output_size, (scales_h, scales_w)
  )
  return _upsample_nearest2d_common(input, row_idx, col_idx)
112+
113+
114+
def get_scale_value(scales, idx):
  """Returns ``scales[idx]``, or None when no scale factors were supplied."""
  return None if scales is None else scales[idx]
118+
119+
120+
@functools.partial(
    fx_infra.decomp.add_pre_lower_decomp,
    torch.ops.aten.upsample_nearest2d.vec,
)
@fx_infra.annotate_force_decomp
def upsample_nearest2d_vec(input, output_size, scale_factors):
  """Pre-lower decomposition for aten.upsample_nearest2d.vec.

  Resolves the concrete spatial output size, then defers to the .default
  overload (itself decomposed above into an aten.index gather).
  """
  resolved_size = upsample_compute_output_size(
      input.size(), output_size, scale_factors
  )
  return torch.ops.aten.upsample_nearest2d.default(
      input,
      resolved_size,
      get_scale_value(scale_factors, 0),
      get_scale_value(scale_factors, 1),
  )
134+
43135
# NOTE(review): presumably aten.roll has a direct odml_torch lowering, making
# its pre-lower decomposition unnecessary — confirm against the lowerings.
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
44136

45137
# Torch's default einsum impl/decompositions is less efficient and

0 commit comments

Comments (0)