Skip to content

Commit 6041e9f

Browse files
authored
[Relax][PyTorch] Add support for antialiased bilinear upsampling (#18500)
## Related Issue closes #18365 ## How - add support for antialiased bilinear upsampling
1 parent ec7f59f commit 6041e9f

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,22 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
298298
x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners
299299
)
300300

301+
def _upsample_bilinear2d_aa(self, node: fx.Node) -> relax.Var:
302+
x = self.env[node.args[0]]
303+
size = node.args[1] if len(node.args) > 1 else node.kwargs.get("output_size", None)
304+
align_corners = (
305+
node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", False)
306+
)
307+
scale_factor = (
308+
node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factors", None)
309+
)
310+
311+
# Note: TVM's resize2d doesn't have explicit antialias support.
312+
# For upsampling, antialiasing has minimal effect, so we use regular bilinear.
313+
return self._upsample_impl(
314+
x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners
315+
)
316+
301317
def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
302318
x = self.env[node.args[0]]
303319
size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
@@ -1218,6 +1234,7 @@ def create_convert_map(
12181234
"scaled_dot_product_attention.default": self._scaled_dot_product_attention,
12191235
"unbind.int": self._unbind,
12201236
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
1237+
"_upsample_bilinear2d_aa.default": self._upsample_bilinear2d_aa,
12211238
"upsample_nearest2d.vec": self._upsample_nearest2d,
12221239
"upsample_bicubic2d.vec": self._upsample_bicubic2d,
12231240
# statistical

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4703,6 +4703,43 @@ def main(
47034703
verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)
47044704

47054705

4706+
def test_interpolate_antialiased():
4707+
"""Test bilinear interpolation with antialiasing enabled."""
4708+
4709+
class InterpolateBilinearAA(Module):
4710+
def forward(self, input):
4711+
return torch.nn.functional.interpolate(
4712+
input, size=(64, 64), mode="bilinear", align_corners=False, antialias=True
4713+
)
4714+
4715+
@tvm.script.ir_module
4716+
class expected_bilinear_aa:
4717+
@R.function
4718+
def main(
4719+
input: R.Tensor((1, 3, 32, 32), dtype="float32")
4720+
) -> R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")):
4721+
with R.dataflow():
4722+
lv: R.Tensor((1, 3, 64, 64), dtype="float32") = R.image.resize2d(
4723+
input,
4724+
R.shape([64, 64]),
4725+
roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
4726+
layout="NCHW",
4727+
method="linear",
4728+
coordinate_transformation_mode="half_pixel",
4729+
rounding_method="round",
4730+
cubic_alpha=-0.75,
4731+
cubic_exclude=0,
4732+
extrapolation_value=0.0,
4733+
out_dtype="void",
4734+
)
4735+
gv: R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")) = (lv,)
4736+
R.output(gv)
4737+
return gv
4738+
4739+
example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
4740+
verify_model(InterpolateBilinearAA(), example_args, {}, expected_bilinear_aa)
4741+
4742+
47064743
def test_mean():
47074744
class Mean(Module):
47084745
def forward(self, input):

0 commit comments

Comments
 (0)