3
3
import torch .nn as nn
4
4
from parameterized import parameterized
5
5
from torch .testing ._internal .common_utils import run_tests
6
- from torch_tensorrt .dynamo .conversion import UnsupportedOperatorException
7
-
8
6
from torch_tensorrt import Input
7
+ from torch_tensorrt .dynamo .conversion import UnsupportedOperatorException
9
8
10
9
from .harness import DispatchTestCase
11
10
@@ -29,16 +28,32 @@ def forward(self, x):
29
28
inputs = [torch .randn (1 , 2 , 3 )]
30
29
self .run_test (Unsqueeze (dim ), inputs )
31
30
32
- # Testing with more than one dynamic dims results in following error:
33
- # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims.
34
-
35
31
@parameterized .expand (
36
32
[
37
- ("negative_dim_dynamic" , - 4 ),
38
- ("positive_dim_dynamic" , 1 ),
33
+ ("1_dynamic_shape_2d_-3" , - 3 , (2 , 5 ), (3 , 5 ), (4 , 5 )),
34
+ ("1_dynamic_shape_2d_-2" , - 2 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
35
+ ("1_dynamic_shape_2d_-1" , - 1 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
36
+ ("1_dynamic_shape_2d_0" , 0 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
37
+ ("1_dynamic_shape_2d_1" , 1 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
38
+ ("1_dynamic_shape_2d_2" , 2 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
39
+ ("2_dynamic_shape_3d_-1" , - 1 , (2 , 2 , 3 ), (4 , 3 , 3 ), (5 , 5 , 3 )),
40
+ ("2_dynamic_shape_3d_0" , 2 , (2 , 2 , 3 ), (4 , 3 , 3 ), (5 , 5 , 3 )),
41
+ ("2_dynamic_shape_3d_1" , 1 , (2 , 2 , 3 ), (4 , 3 , 3 ), (5 , 6 , 3 )),
42
+ ("2_dynamic_shape_3d_2" , 2 , (2 , 2 , 3 ), (4 , 3 , 3 ), (6 , 5 , 3 )),
43
+ ("4_dynamic_shape_4d_-4" , - 4 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
44
+ ("4_dynamic_shape_4d_-3" , - 3 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
45
+ ("4_dynamic_shape_4d_-2" , - 2 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (4 , 3 , 5 , 6 )),
46
+ ("4_dynamic_shape_4d_-1" , - 1 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (4 , 3 , 5 , 6 )),
47
+ ("4_dynamic_shape_4d_0" , 0 , (1 , 2 , 3 , 4 ), (2 , 2 , 5 , 7 ), (2 , 3 , 6 , 8 )),
48
+ ("4_dynamic_shape_4d_1" , 1 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
49
+ ("4_dynamic_shape_4d_2" , 2 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
50
+ ("4_dynamic_shape_4d_3" , 3 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
51
+ ("4_dynamic_shape_4d_4" , 4 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
39
52
]
40
53
)
41
- def test_unsqueeze_with_dynamic_shape (self , _ , dim ):
54
+ def test_unsqueeze_with_dynamic_shape (
55
+ self , _ , dim , min_shape , opt_shape , max_shape
56
+ ):
42
57
class Unsqueeze (nn .Module ):
43
58
def __init__ (self , dim ):
44
59
super ().__init__ ()
@@ -49,9 +64,10 @@ def forward(self, x):
49
64
50
65
input_specs = [
51
66
Input (
52
- shape = (- 1 , 2 , 3 ),
53
67
dtype = torch .float32 ,
54
- shape_ranges = [((1 , 2 , 3 ), (2 , 2 , 3 ), (3 , 2 , 3 ))],
68
+ min_shape = min_shape ,
69
+ opt_shape = opt_shape ,
70
+ max_shape = max_shape ,
55
71
),
56
72
]
57
73
self .run_test_with_dynamic_shape (Unsqueeze (dim ), input_specs )
0 commit comments