Skip to content

Commit 1f42040

Browse files
ai-edge-botcopybara-github
authored andcommitted
Use standard PyTorch decomposition for upsample_nearest2d.
Replaced custom implementations of `upsample_nearest2d` decompositions with the standard decomposition provided by PyTorch. This aligns with upstream practices and reduces maintenance overhead. PiperOrigin-RevId: 854101690
1 parent 56b4375 commit 1f42040

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

ai_edge_torch/odml_torch/lowerings/_decomp_registry.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def _upsample_nearest2d_common(input, h_indices, w_indices):
9595
torch.ops.aten.replication_pad3d,
9696
torch.ops.aten.upsample_bilinear2d.vec,
9797
torch.ops.aten.addmm,
98+
torch.ops.aten.upsample_nearest2d.vec,
9899
])
99100
)
100101

@@ -117,19 +118,7 @@ def get_scale_value(scales, idx):
117118
return scales[idx]
118119

119120

120-
@functools.partial(
121-
fx_infra.decomp.add_pre_lower_decomp,
122-
torch.ops.aten.upsample_nearest2d.vec,
123-
)
124-
@fx_infra.annotate_force_decomp
125-
def upsample_nearest2d_vec(input, output_size, scale_factors):
126-
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
127-
scale_h = get_scale_value(scale_factors, 0)
128-
scale_w = get_scale_value(scale_factors, 1)
129121

130-
return torch.ops.aten.upsample_nearest2d.default(
131-
input, osize, scale_h, scale_w
132-
)
133122

134123

135124
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)

0 commit comments

Comments
 (0)