Commit 49973d1

[Relax][PyTorch] Support advanced range constraints (multiplication) (#18463)
## Related Issue

- #17818

## Why

- Add support for multiplication expressions (e.g., `s0 * 2`) in PyTorch dynamic shape constraints.

## How

- Parse `SymPy` multiplication expressions from PyTorch's `range_constraints`.
1 parent: 83db389
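
For orientation, here is a minimal sketch of the term-by-term folding idea this commit applies in `_process_derived_symbol` (see the diff below): a `sympy.Mul` such as `2*s0` is walked argument by argument and folded into a `tvm.tir` expression with `*`, just as the existing `Add` path folds with `+`. This is an illustration of the approach, not code from the patch; the variable names are ours.

```python
# Sketch of the SymPy -> TIR folding used for derived range constraints.
import sympy
import tvm

s0 = sympy.Symbol("s0")
expr = 2 * s0  # PyTorch's range_constraints expose products like this as sympy.Mul

tir_expr = None
for arg in expr.args:  # sympy.Mul args for 2*s0 are (2, s0)
    if isinstance(arg, sympy.Integer):
        term = tvm.tir.IntImm("int64", int(arg))
    elif isinstance(arg, sympy.Symbol):
        term = tvm.tir.SizeVar(str(arg), "int64")
    else:
        term = None  # unsupported sub-expression; the real code bails out here

    if term is None:
        break
    # Fold with * for a Mul; the pre-existing Add path folds with + instead.
    tir_expr = term if tir_expr is None else tir_expr * term

print(tir_expr)  # e.g. a tvm.tir.Mul printed as T.int64(2) * s0
```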

File tree

2 files changed (+96, -9 lines)


python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 27 additions & 8 deletions

```diff
@@ -1191,7 +1191,7 @@ def _process_derived_symbol(
         if isinstance(symbol, sympy.Symbol):
             return str(symbol), None

-        if not isinstance(symbol, sympy.Add):
+        if not isinstance(symbol, (sympy.Add, sympy.Mul)):
             return str(symbol), None

         tir_expr = None
@@ -1207,13 +1207,24 @@ def _process_derived_symbol(

             if term is None:
                 return str(symbol), None
-            tir_expr = term if tir_expr is None else tir_expr + term
+
+            if tir_expr is None:
+                tir_expr = term
+            elif isinstance(symbol, sympy.Mul):
+                tir_expr = tir_expr * term
+            elif isinstance(symbol, sympy.Add):
+                tir_expr = tir_expr + term

         if isinstance(tir_expr, tvm.tir.Add):
             for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
                 if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
                     return f"{var.name}___{const.value}", tir_expr

+        if isinstance(tir_expr, tvm.tir.Mul):
+            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
+                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
+                    return f"{var.name}_{const.value}", tir_expr
+
         return str(symbol), tir_expr

     def create_input_vars(
@@ -1256,12 +1267,20 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype

-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            relax_shape = []
+            for s in torch_shape:
+                if isinstance(s, torch.SymInt):
+                    sympy_node = s.node.expr if hasattr(s.node, "expr") else s.node
+                    symbol_name, _ = self._process_derived_symbol(
+                        sympy_node, torch_symbol_to_relax_var
+                    )
+
+                    size_var = torch_symbol_to_relax_var.setdefault(
+                        symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
+                    )
+                    relax_shape.append(size_var)
+                else:
+                    relax_shape.append(s)
             dtype = self._convert_data_type(torch_dtype)

             relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
```
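
One detail worth flagging in the hunk above is the derived-symbol naming scheme: an `Add` of a variable and a constant is named `<var>___<const>` (triple underscore), while the new `Mul` branch uses `<var>_<const>` (single underscore). A tiny self-contained illustration of that convention; the helper function here is ours, not part of TVM:

```python
import tvm

def derived_symbol_name(var: tvm.tir.Var, const: int, is_mul: bool) -> str:
    # Mirrors the two f-strings in the diff above: single underscore for
    # multiplication, triple underscore for addition.
    return f"{var.name}_{const}" if is_mul else f"{var.name}___{const}"

s0 = tvm.tir.SizeVar("s0", "int64")
assert derived_symbol_name(s0, 2, is_mul=True) == "s0_2"     # from s0 * 2
assert derived_symbol_name(s0, 1, is_mul=False) == "s0___1"  # from s0 + 1
```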

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 69 additions & 1 deletion

```diff
@@ -7028,7 +7028,7 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)


-def test_dynamic_shape_with_derived_range_constraints():
+def test_dynamic_shape_with_addition_constraints():
     class ConcatModel(torch.nn.Module):
         def forward(self, x, y):
             return torch.cat([x, y], dim=0)
@@ -7062,5 +7062,73 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


+def test_dynamic_shape_with_subtraction_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")):
+            s1___1 = T.int64(is_size_var=True)
+            s1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s1": 0, "s1___1": 1},
+                    "tir_var_upper_bound": {"s1": 63, "s1___1": 64},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(7, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
+def test_dynamic_shape_with_multiplication_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0_2 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s0": 1, "s0_2": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0_2": 128},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(16, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```
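
Condensed from the new multiplication test above, the end-user flow looks like this (assuming a TVM build that includes this commit):

```python
import torch
from torch.export import Dim, export
from tvm.relax.frontend.torch import from_exported_program

class ConcatModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y], dim=0)

# Constrain y's leading dim to be exactly twice x's batch dim.
batch = Dim("batch", min=1, max=64)
ep = export(
    ConcatModel(),
    args=(torch.randn(8, 4), torch.randn(16, 4)),
    dynamic_shapes={"x": {0: batch}, "y": {0: batch * 2}},
)

mod = from_exported_program(ep, run_ep_decomposition=True)
mod.show()  # the derived dimension appears as the SizeVar "s0_2"
```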
