Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,19 +1383,24 @@ def _process_derived_symbol(

def create_input_vars(
self, exported_program: torch.export.ExportedProgram
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, Optional[int]]]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
range_constraints = {}

if hasattr(exported_program, "range_constraints"):
import math

for symbol, value_range in exported_program.range_constraints.items():
if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
try:
# PyTorch uses int_oo (IntInfinity) for unbounded constraints
lower = int(value_range.lower)
upper = int(value_range.upper)
upper = (
None if math.isinf(float(value_range.upper)) else int(value_range.upper)
)

symbol_name, _ = self._process_derived_symbol(
symbol, torch_symbol_to_relax_var
Expand Down Expand Up @@ -1469,13 +1474,22 @@ def from_exported_program(
func_name = "main"
func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {}
if range_constraints:
func_attrs["tir_var_lower_bound"] = {
var_name: lower for var_name, (lower, _) in range_constraints.items()
lower_bounds = {
var_name: lower for var_name, (lower, _) in range_constraints.items() if lower != 0
}
func_attrs["tir_var_upper_bound"] = {
var_name: upper for var_name, (_, upper) in range_constraints.items()

upper_bounds = {
var_name: upper
for var_name, (_, upper) in range_constraints.items()
if upper is not None
}

if upper_bounds:
func_attrs["tir_var_upper_bound"] = upper_bounds

if lower_bounds:
func_attrs["tir_var_lower_bound"] = lower_bounds

nodes: List[fx.Node] = exported_program.graph.nodes

# Find all the missing function types
Expand Down
28 changes: 28 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7909,6 +7909,34 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


def test_dynamic_shape_with_unbounded_constraints():
    """A dynamic dim with a min but no max must yield only ``tir_var_lower_bound``.

    ``torch.export.Dim("batch", min=2)`` has no maximum, so PyTorch records an
    unbounded (int_oo) upper range constraint.  The translator should then attach
    ``tir_var_lower_bound`` for ``s0`` but omit ``tir_var_upper_bound`` entirely,
    as reflected in the ``Expected`` module below.
    """

    class DynamicModel(torch.nn.Module):
        # Minimal op on the input so the graph carries the symbolic batch dim.
        def forward(self, x):
            return torch.ops.aten.add.Tensor(x, x)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(("s0", 4), dtype="float32")
        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
            s0 = T.int64(is_size_var=True)
            # Only the lower bound appears: "batch" was exported with min=2 and
            # no max, so no "tir_var_upper_bound" attribute is expected.
            R.func_attr({"tir_var_lower_bound": {"s0": 2}})
            with R.dataflow():
                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x)
                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(8, 4),)
    # min=2 with no max -> unbounded upper constraint in the exported program.
    batch = torch.export.Dim("batch", min=2)
    dynamic_shapes = {"x": {0: batch}}
    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)

    mod = from_exported_program(exported_program)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_sym_size_int():
class SymSizeInt(Module):
def __init__(self, dim):
Expand Down
Loading