Commit 4a60bb2

[Relax][PyTorch] Add support for torch.ops.aten.sym_size.int in ExportedProgram frontend (#18473)
As per title. cc @tlopex
1 parent c8515e1
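For context, torch.ops.aten.sym_size.int(tensor, dim) is the ATen overload behind tensor size queries. A minimal sketch of a module that is intended to put a sym_size.int node into an ExportedProgram graph (the module and shapes are illustrative assumptions, not part of this commit; whether torch.export specializes the call away can depend on the torch version):

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        # Calling the aten op directly, as the new test does, is meant to keep
        # a sym_size.int node in the exported graph despite the static shape.
        return x + torch.ops.aten.sym_size.int(x, 1)

ep = export(M(), (torch.randn(1, 3, 4),))

On conversion, the new handler folds such a node into a Relax constant (R.const(3, "int32") for this shape), as the test added below demonstrates.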

4 files changed (+38, −6 lines changed)

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 6 additions & 0 deletions

@@ -2357,6 +2357,12 @@ def _item(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0))
 
+    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.const(int(shape[dim]), "int32"))
+
     def _zeros_inplace(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         output = self.block_builder.emit(relax.op.zeros_like(x))
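Because the ExportedProgram frontend sees static tensor shapes here, the handler can fold the size query into a constant: int(shape[dim]) is ordinary sequence indexing, so negative dims resolve through Python's negative indexing with no extra normalization. A hedged illustration in plain Python (not library code):

shape = (1, 3, 4)            # static shape recorded for the test input below
assert int(shape[1]) == 3    # dim = 1
assert int(shape[-2]) == 3   # dim = -2 selects the same dimension

This is why the test added below can check dim=1 and dim=-2 against a single Expected module.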

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions

@@ -1189,6 +1189,7 @@ def create_convert_map(
             # other
             "getitem": self._getitem,
             "item.default": self._item,
+            "sym_size.int": self._sym_size_int,
             "_local_scalar_dense.default": self._item,
         }
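The convert map appears to be keyed by ATen overload name, which is presumably how call_function nodes are dispatched (an assumption about the surrounding translator, not shown in this diff). A quick sanity check of the key format:

import torch

op = torch.ops.aten.sym_size.int
print(op.__name__)  # expected: "sym_size.int", matching the new map entry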

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 0 additions & 6 deletions

The FX translator's own copy of _sym_size_int is dropped in favor of the shared implementation added to base_fx_graph_translator.py above, so both frontends now use the same handler.

@@ -730,12 +730,6 @@ def _getattr(self, node: fx.Node) -> relax.Var:
             return self.shape_of(self.env[node.args[0]])
         return getattr(self.env[node.args[0]], node.args[1])
 
-    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
-        x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        idx = node.args[1]
-        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
-
     def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]:
         inputs = list()
         for idx, (shape, dtype) in enumerate(input_info):

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 31 additions & 0 deletions

@@ -7508,5 +7508,36 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_sym_size_int():
+    class SymSizeInt(Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, x):
+            # TODO(@mshr-h): `x.size(self.dim)` would be ideal, but currently
+            # the ep frontend is not able to handle it.
+            return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.take(
+                    x, R.const(0, "int64"), axis=0, mode="fast"
+                )
+                lv1: R.Tensor((3, 4), dtype="float32") = R.add(lv, R.const(3.0, "float32"))
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 4),)
+    verify_model(SymSizeInt(dim=1), example_args, {}, Expected)
+    verify_model(SymSizeInt(dim=-2), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
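For a standalone reproduction outside the test harness, a sketch of the flow (verify_model's internals are assumed to follow an export-convert-compare pattern; only the assert_structural_equal call with map_free_vars=True is taken from this file):

import torch
import tvm
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

# Export the module, convert the ExportedProgram to Relax, then compare
# structurally against the Expected IRModule defined in the test.
ep = export(SymSizeInt(dim=1), (torch.randn(1, 3, 4),))
mod = from_exported_program(ep)
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)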
