diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index ac79024acfb9..95b0e05361aa 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1383,7 +1383,7 @@ def _process_derived_symbol( def create_input_vars( self, exported_program: torch.export.ExportedProgram - ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]: + ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, Optional[int]]]]: """Create relax input vars.""" parameters_buffers_constants = OrderedDict() user_inputs = OrderedDict() @@ -1391,11 +1391,16 @@ def create_input_vars( range_constraints = {} if hasattr(exported_program, "range_constraints"): + import math + for symbol, value_range in exported_program.range_constraints.items(): if hasattr(value_range, "lower") and hasattr(value_range, "upper"): try: + # PyTorch uses int_oo (IntInfinity) for unbounded constraints lower = int(value_range.lower) - upper = int(value_range.upper) + upper = ( + None if math.isinf(float(value_range.upper)) else int(value_range.upper) + ) symbol_name, _ = self._process_derived_symbol( symbol, torch_symbol_to_relax_var @@ -1472,10 +1477,16 @@ def from_exported_program( func_attrs["tir_var_lower_bound"] = { var_name: lower for var_name, (lower, _) in range_constraints.items() } - func_attrs["tir_var_upper_bound"] = { - var_name: upper for var_name, (_, upper) in range_constraints.items() + + upper_bounds = { + var_name: upper + for var_name, (_, upper) in range_constraints.items() + if upper is not None } + if upper_bounds: + func_attrs["tir_var_upper_bound"] = upper_bounds + nodes: List[fx.Node] = exported_program.graph.nodes # Find all the missing function types diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 78a8a09a3cf4..d4c23bfdd5d0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7206,6 +7206,7 @@ def main( lhs: R.Tensor((B, 4), dtype="float32"), rhs: R.Tensor((B, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")): + R.func_attr({"tir_var_lower_bound": {"s0": 0}}) with R.dataflow(): lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs) gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,) @@ -7909,6 +7910,34 @@ def main( tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True) +def test_dynamic_shape_with_unbounded_constraints(): + class DynamicModel(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Tensor(x, x) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(("s0", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 2}}) + with R.dataflow(): + lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x) + gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(8, 4),) + batch = torch.export.Dim("batch", min=2) + dynamic_shapes = {"x": {0: batch}} + exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes) + + mod = from_exported_program(exported_program) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_sym_size_int(): class SymSizeInt(Module): def __init__(self, dim): @@ -7955,6 +7984,7 @@ def main( x: R.Tensor(("s0", 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")): s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 0}}) with R.dataflow(): lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12])) gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)