
Commit 161049e

[Relax][PyTorch] Enhance handling of unbounded upper bound constraints (#18489)
## Why
PyTorch uses int_oo (IntInfinity) for unbounded constraints, which would make our current implementation crash.

## How
- Update the type hint for `create_input_vars` to allow optional upper bounds.
- Modify the logic to handle unbounded constraints by setting the upper bound to None when applicable.
- Add a new test case.
1 parent ced7181 commit 161049e
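For context, a minimal sketch of the condition this commit handles (the module name `AddSelf` and the printed form are illustrative, not taken from the patch): exporting with a `Dim` that has a `min` but no `max` leaves the symbol's upper bound unbounded.

```python
import torch

class AddSelf(torch.nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return x + x

# A Dim with a min but no max leaves the upper bound unbounded.
batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(
    AddSelf(), (torch.randn(8, 4),), dynamic_shapes={"x": {0: batch}}
)

# PyTorch represents the missing upper bound as int_oo (IntInfinity),
# which a bare int(value_range.upper) cannot convert, hence the crash.
print(ep.range_constraints)  # e.g. {s0: ValueRanges(lower=2, upper=int_oo)}
```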

File tree

2 files changed: +45 −4 lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -1383,19 +1383,24 @@ def _process_derived_symbol(
 
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, Optional[int]]]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
         range_constraints = {}
 
         if hasattr(exported_program, "range_constraints"):
+            import math
+
             for symbol, value_range in exported_program.range_constraints.items():
                 if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
                     try:
+                        # PyTorch uses int_oo (IntInfinity) for unbounded constraints
                         lower = int(value_range.lower)
-                        upper = int(value_range.upper)
+                        upper = (
+                            None if math.isinf(float(value_range.upper)) else int(value_range.upper)
+                        )
 
                         symbol_name, _ = self._process_derived_symbol(
                             symbol, torch_symbol_to_relax_var
@@ -1472,10 +1477,16 @@ def from_exported_program(
         func_attrs["tir_var_lower_bound"] = {
             var_name: lower for var_name, (lower, _) in range_constraints.items()
         }
-        func_attrs["tir_var_upper_bound"] = {
-            var_name: upper for var_name, (_, upper) in range_constraints.items()
+
+        upper_bounds = {
+            var_name: upper
+            for var_name, (_, upper) in range_constraints.items()
+            if upper is not None
         }
 
+        if upper_bounds:
+            func_attrs["tir_var_upper_bound"] = upper_bounds
+
     nodes: List[fx.Node] = exported_program.graph.nodes
 
     # Find all the missing function types
```
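The new bound normalization can be read in isolation; here is a minimal standalone sketch of the same logic, assuming plain numbers stand in for PyTorch's ValueRanges endpoints (the helper name `normalize_bound` is hypothetical, not from the patch):

```python
import math
from typing import Optional, Tuple

def normalize_bound(lower, upper) -> Tuple[int, Optional[int]]:
    # int_oo converts to float("inf"), so map an infinite upper bound
    # to None instead of letting int() crash on it.
    return int(lower), (None if math.isinf(float(upper)) else int(upper))

assert normalize_bound(2, 8) == (2, 8)
assert normalize_bound(2, float("inf")) == (2, None)
```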

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -7206,6 +7206,7 @@ def main(
         lhs: R.Tensor((B, 4), dtype="float32"),
         rhs: R.Tensor((B, 4), dtype="float32"),
     ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+        R.func_attr({"tir_var_lower_bound": {"s0": 0}})
         with R.dataflow():
             lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
             gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
@@ -7909,6 +7910,34 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_dynamic_shape_with_unbounded_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x):
+            return torch.ops.aten.add.Tensor(x, x)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_lower_bound": {"s0": 2}})
+            with R.dataflow():
+                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x)
+                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(8, 4),)
+    batch = torch.export.Dim("batch", min=2)
+    dynamic_shapes = {"x": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_sym_size_int():
     class SymSizeInt(Module):
         def __init__(self, dim):
@@ -7955,6 +7984,7 @@ def main(
         x: R.Tensor(("s0", 3, 4), dtype="float32")
     ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")):
         s0 = T.int64(is_size_var=True)
+        R.func_attr({"tir_var_lower_bound": {"s0": 0}})
         with R.dataflow():
             lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12]))
             gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)
```
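The net effect on a converted module: when a symbol's upper bound is unbounded, `tir_var_upper_bound` is omitted entirely rather than set to a bogus value, while the lower bound survives. A follow-on check along the lines of the new test, reusing `exported_program` from it (this attrs inspection is a sketch, not part of the patch):

```python
from tvm.relax.frontend.torch import from_exported_program

mod = from_exported_program(exported_program)
attrs = mod["main"].attrs
assert "tir_var_lower_bound" in attrs      # min=2 survives as a lower bound
assert "tir_var_upper_bound" not in attrs  # unbounded upper bound is dropped
```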
