1+ from torch2trt .torch2trt import *
2+ from torch2trt .module_test import add_module_test
3+
4+
5+ @tensorrt_converter ('torch.Tensor.expand' )
6+ def convert_expand (ctx ):
7+ input = ctx .method_args [0 ]
8+ sizes = ctx .method_args [1 :]
9+ output = ctx .method_return
10+
11+ inshape = tuple (input .shape )[1 :] # exclude batch
12+ shape = tuple (output .shape )[1 :]
13+ ndim = len (shape )
14+ start = tuple ([0 ]* ndim )
15+ stride = tuple ([int (i == o ) for i , o in zip (inshape , shape )]) # stride == 1 if dimensions match, 0 otherwise
16+
17+ layer = ctx .network .add_slice (input ._trt , start , shape , stride )
18+
19+ output ._trt = layer .get_output (0 )
20+
21+
22+ class ExpandModule (torch .nn .Module ):
23+ def __init__ (self , * sizes ):
24+ super (ExpandModule , self ).__init__ ()
25+ self .sizes = sizes
26+
27+ def forward (self , x ):
28+ return x .expand (* self .sizes )
29+
30+
31+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 ,1 ,3 ,3 )])
32+ def test_tensor_expand_singledim ():
33+ return ExpandModule (1 , 3 , 3 , 3 )
34+
35+
36+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 ,1 ,1 ,3 )])
37+ def test_tensor_expand_multidim ():
38+ return ExpandModule (1 , 3 , 3 , 3 )
39+
40+
41+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 ,1 ,1 ,3 )])
42+ def test_tensor_expand_inferdim ():
43+ return ExpandModule (1 , 3 , - 1 , - 1 )
0 commit comments