
Commit ea89f21

[Relax][PyTorch] Support advanced range constraints (addition) (#18452)
## Related Issue
- #17818

## Why
- Add support for addition expressions (e.g., `s0 + 1`) in PyTorch dynamic shape constraints

## How
- Parse `SymPy` addition expressions from PyTorch's `range_constraints`
1 parent 0701aab commit ea89f21
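
For context, a minimal sketch of how such a derived constraint arises, mirroring the new test added below: `Dim` arithmetic in `dynamic_shapes` is what makes `torch.export` record an addition expression like `s0 + 1` among the keys of `range_constraints`.

```python
import torch

class ConcatModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y], dim=0)

# Dim arithmetic ("batch + 1") yields a derived sympy expression such as
# s0 + 1 among the keys of exported_program.range_constraints.
batch = torch.export.Dim("batch", min=1, max=64)
exported_program = torch.export.export(
    ConcatModel(),
    args=(torch.randn(8, 4), torch.randn(9, 4)),
    dynamic_shapes={"x": {0: batch}, "y": {0: batch + 1}},
)
# Keys include a plain symbol (s0) and the derived expression (s0 + 1),
# each mapped to its value range.
print(exported_program.range_constraints)
```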

File tree

2 files changed: +75 -5 lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 41 additions & 5 deletions
```diff
@@ -20,7 +20,7 @@
 """PyTorch ExportedProgram of Relax."""
 from collections import ChainMap, OrderedDict
 from functools import partial
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 import tvm
@@ -1181,6 +1181,40 @@ def create_convert_map(
             "_local_scalar_dense.default": self._item,
         }
 
+    def _process_derived_symbol(
+        self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
+    ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
+        """Process a sympy symbol to generate a descriptive name and TIR expression."""
+        import sympy
+
+        if isinstance(symbol, sympy.Symbol):
+            return str(symbol), None
+
+        if not isinstance(symbol, sympy.Add):
+            return str(symbol), None
+
+        tir_expr = None
+        for arg in symbol.args:
+            if isinstance(arg, sympy.Integer):
+                term = tvm.tir.IntImm("int64", int(arg))
+            elif isinstance(arg, sympy.Symbol):
+                term = torch_symbol_to_relax_var.setdefault(
+                    str(arg), tvm.tir.SizeVar(str(arg), "int64")
+                )
+            else:
+                _, term = self._process_derived_symbol(arg, torch_symbol_to_relax_var)
+
+            if term is None:
+                return str(symbol), None
+            tir_expr = term if tir_expr is None else tir_expr + term
+
+        if isinstance(tir_expr, tvm.tir.Add):
+            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
+                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
+                    return f"{var.name}___{const.value}", tir_expr
+
+        return str(symbol), tir_expr
+
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
     ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
@@ -1192,12 +1226,16 @@ def create_input_vars(
 
         if hasattr(exported_program, "range_constraints"):
            for symbol, value_range in exported_program.range_constraints.items():
-                symbol_name = str(symbol)
                 if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
                     try:
                         lower = int(value_range.lower)
                         upper = int(value_range.upper)
+
+                        symbol_name, _ = self._process_derived_symbol(
+                            symbol, torch_symbol_to_relax_var
+                        )
                         range_constraints[symbol_name] = (lower, upper)
+
                     except (OverflowError, AttributeError, TypeError):
                         continue
 
@@ -1255,10 +1293,8 @@ def from_exported_program(
         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
-        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {}
         if range_constraints:
-            if func_attrs is None:
-                func_attrs = {}
             func_attrs["tir_var_lower_bound"] = {
                 var_name: lower for var_name, (lower, _) in range_constraints.items()
             }
```
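
For reference, a small standalone sketch of the sympy structure the new `_process_derived_symbol` helper walks; the symbol name `s0` is illustrative only:

```python
import sympy

s0 = sympy.Symbol("s0")
expr = s0 + 1  # the kind of derived symbol PyTorch records for Dim("batch") + 1

# An addition is a sympy.Add whose args are the individual terms:
# an Integer term and a Symbol term in this case.
assert isinstance(expr, sympy.Add)
assert set(expr.args) == {sympy.Integer(1), s0}

# The helper maps each term to TIR (Integer -> IntImm, Symbol -> SizeVar)
# and derives the readable name "s0___1" for s0 + 1.
```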

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -7000,5 +7000,39 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_shape_with_derived_range_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0___1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s0": 1, "s0___1": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0___1": 65},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(9, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```
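
On the TIR side, the derived variable name (`s0___1` in the test above) comes from pattern-matching the operands of a `tvm.tir.Add`. A minimal sketch of that invariant, with illustrative variable names:

```python
import tvm

s0 = tvm.tir.SizeVar("s0", "int64")
expr = s0 + tvm.tir.IntImm("int64", 1)  # the TIR expression built for s0 + 1

# The helper checks both operand orderings, (a, b) and (b, a), since the
# IntImm can land on either side, and emits the name "<var>___<const>".
assert isinstance(expr, tvm.tir.Add)
assert any(isinstance(op, tvm.tir.IntImm) for op in (expr.a, expr.b))
```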
