
Commit 5842bdb

[Relax][PyTorch] Add support for broadcast_to, narrow ops (#17820)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
Parent: 03ba03e · Commit: 5842bdb

3 files changed: +58 −0 lines

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 6 additions & 0 deletions
@@ -972,6 +972,12 @@ def _argsort(self, node: fx.Node) -> relax.Var:
         descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
         return self.block_builder.emit(relax.op.argsort(x, dim, descending))
 
+    def _broadcast_to(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        shape = args[1] if len(args) > 1 else args[0]
+        return self.block_builder.emit(relax.op.broadcast_to(x, shape))
+
     def _cat(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
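
For context, a minimal end-to-end sketch (not part of the commit) of how this handler is reached: trace a module that calls torch.broadcast_to with torch.fx and convert it through the Relax FX frontend. The module name and shapes are illustrative, and tvm.relax.frontend.torch.from_fx is assumed as the entry point.

import torch
from torch import fx, nn

from tvm.relax.frontend.torch import from_fx  # assumed public entry point


class BroadcastModel(nn.Module):  # illustrative module, not from the commit
    def forward(self, x):
        # This call is dispatched to _broadcast_to and lowered to
        # relax.op.broadcast_to by the new handler.
        return torch.broadcast_to(x, (5, 3))


graph_model = fx.symbolic_trace(BroadcastModel())
mod = from_fx(graph_model, [([5, 1], "float32")])
mod.show()  # print the converted Relax module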

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 9 additions & 0 deletions
@@ -429,6 +429,13 @@ def _flatten_module(self, node: fx.Node) -> relax.Var:
         end_dim = module.end_dim
         return self._flatten_impl(x, start_dim, end_dim)
 
+    def _narrow(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        start = node.args[2]
+        length = node.args[3]
+        return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length]))
+
     def _numel(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -764,6 +771,7 @@ def create_convert_map(
             "where": self._where,
             # tensor manipulation
             "argsort": self._argsort,
+            "broadcast_to": self._broadcast_to,
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
@@ -775,6 +783,7 @@ def create_convert_map(
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
+            "narrow": self._narrow,
             "numel": self._numel,
             "permute": self._permute,
             "repeat": self._repeat,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 43 additions & 0 deletions
@@ -4470,5 +4470,48 @@ def main(
     verify_model(Topk(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_broadcast_to():
+    class BroadcastTo(Module):
+        def forward(self, x):
+            return torch.broadcast_to(x, (5, 3))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 1), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0, (5, 3))
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(BroadcastTo(), [([5, 1], "float32")], {}, Expected)
+
+
+def test_narrow():
+    class Narrow(Module):
+        def forward(self, x):
+            return torch.narrow(x, 1, 0, 2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
+                    inp_0, axes=[1], begin=[0], end=[2]
+                )
+                gv: R.Tensor((5, 2), dtype="float32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
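
For readers unfamiliar with the test harness, verify_model roughly traces the module with FX, converts it via from_fx, and compares the result structurally against the Expected TVMScript module. A hedged sketch of that flow is below; the helper name check_against_expected is hypothetical and the actual verify_model in the test file may differ in details (e.g. it also accepts a parameter-binding dict).

import torch
from torch import fx

import tvm
from tvm.relax.frontend.torch import from_fx


def check_against_expected(torch_module, input_info, expected_mod):
    # Trace the nn.Module to an FX GraphModule, convert it to Relax,
    # and require structural equality with the hand-written Expected module.
    graph_model = fx.symbolic_trace(torch_module)
    converted = from_fx(graph_model, input_info)
    tvm.ir.assert_structural_equal(converted, expected_mod)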
