From 296def5261493065f4d6cbe0d2724f647ac9dc45 Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Fri, 9 Jan 2026 02:27:47 -0800 Subject: [PATCH] Use standard PyTorch decomposition for `upsample_nearest2d`. Replaced custom implementations of `upsample_nearest2d` decompositions with the standard decomposition provided by PyTorch. This aligns with upstream practices and reduces maintenance overhead. PiperOrigin-RevId: 854101690 --- .../odml_torch/lowerings/_decomp_registry.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py b/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py index cd9cd6104..20eca5296 100644 --- a/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +++ b/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py @@ -95,6 +95,7 @@ def _upsample_nearest2d_common(input, h_indices, w_indices): torch.ops.aten.replication_pad3d, torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.addmm, + torch.ops.aten.upsample_nearest2d.vec, ]) ) @@ -117,19 +118,7 @@ def get_scale_value(scales, idx): return scales[idx] -@functools.partial( - fx_infra.decomp.add_pre_lower_decomp, - torch.ops.aten.upsample_nearest2d.vec, -) -@fx_infra.annotate_force_decomp -def upsample_nearest2d_vec(input, output_size, scale_factors): - osize = upsample_compute_output_size(input.size(), output_size, scale_factors) - scale_h = get_scale_value(scale_factors, 0) - scale_w = get_scale_value(scale_factors, 1) - return torch.ops.aten.upsample_nearest2d.default( - input, osize, scale_h, scale_w - ) fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)