diff --git a/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py b/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py index cd9cd6104..20eca5296 100644 --- a/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +++ b/ai_edge_torch/odml_torch/lowerings/_decomp_registry.py @@ -95,6 +95,7 @@ def _upsample_nearest2d_common(input, h_indices, w_indices): torch.ops.aten.replication_pad3d, torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.addmm, + torch.ops.aten.upsample_nearest2d.vec, ]) ) @@ -117,19 +118,7 @@ def get_scale_value(scales, idx): return scales[idx] -@functools.partial( - fx_infra.decomp.add_pre_lower_decomp, - torch.ops.aten.upsample_nearest2d.vec, -) -@fx_infra.annotate_force_decomp -def upsample_nearest2d_vec(input, output_size, scale_factors): - osize = upsample_compute_output_size(input.size(), output_size, scale_factors) - scale_h = get_scale_value(scale_factors, 0) - scale_w = get_scale_value(scale_factors, 1) - return torch.ops.aten.upsample_nearest2d.default( - input, osize, scale_h, scale_w - ) fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)