Skip to content

Commit 033df0c

Browse files
authored
added expand converter (#487)
1 parent 2b1827e commit 033df0c

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Added
66

7+
- Added converter for ``torch.Tensor.expand``
78
- Added support for custom converters for methods defined outside of ``torch`` module
89
- Added names for TensorRT layers
910
- Added GroupNorm plugin which internally uses PyTorch aten::group_norm

torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .clamp import *
2626
from .compare import *
2727
from .div import *
28+
from .expand import *
2829
from .getitem import *
2930
from .identity import *
3031
from .instance_norm import *

torch2trt/converters/expand.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from torch2trt.torch2trt import *
2+
from torch2trt.module_test import add_module_test
3+
4+
5+
@tensorrt_converter('torch.Tensor.expand')
6+
def convert_expand(ctx):
7+
input = ctx.method_args[0]
8+
sizes = ctx.method_args[1:]
9+
output = ctx.method_return
10+
11+
inshape = tuple(input.shape)[1:] # exclude batch
12+
shape = tuple(output.shape)[1:]
13+
ndim = len(shape)
14+
start = tuple([0]*ndim)
15+
stride = tuple([int(i == o) for i, o in zip(inshape, shape)]) # stride == 1 if dimensions match, 0 otherwise
16+
17+
layer = ctx.network.add_slice(input._trt, start, shape, stride)
18+
19+
output._trt = layer.get_output(0)
20+
21+
22+
class ExpandModule(torch.nn.Module):
23+
def __init__(self, *sizes):
24+
super(ExpandModule, self).__init__()
25+
self.sizes = sizes
26+
27+
def forward(self, x):
28+
return x.expand(*self.sizes)
29+
30+
31+
@add_module_test(torch.float32, torch.device('cuda'), [(1,1,3,3)])
32+
def test_tensor_expand_singledim():
33+
return ExpandModule(1, 3, 3, 3)
34+
35+
36+
@add_module_test(torch.float32, torch.device('cuda'), [(1,1,1,3)])
37+
def test_tensor_expand_multidim():
38+
return ExpandModule(1, 3, 3, 3)
39+
40+
41+
@add_module_test(torch.float32, torch.device('cuda'), [(1,1,1,3)])
42+
def test_tensor_expand_inferdim():
43+
return ExpandModule(1, 3, -1, -1)

0 commit comments

Comments
 (0)