-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Tuple, Union
 
+import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-from torch_tensorrt.fx.converters.converter_utils import (
-    has_dynamic_shape,
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_trt_tensor,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 
 """
 Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
 """
 
 
-def constant_padNd(
+def get_padded_shape_tensors(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
     pad: Sequence[int],
-    value: Union[int, float] = 0,
-) -> TRTTensor:
+) -> Tuple[TRTTensor, TRTTensor]:
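+    # Returns a pair of ITensors: the (possibly negative) start indices and the
+    # padded output shape, computed at runtime so dynamic input dims are supported.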
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
-
     if len(pad) // 2 > rank:
         raise RuntimeError(
-            f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
+            f"Trying to pad last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
         )
 
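+    # Materialize the input shape as an ITensor so the padded shape can be
+    # computed even when some dimensions are dynamic at build time.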
+    input_shape_tensor = get_shape_with_dynamic_shape(
+        ctx,
+        target,
+        source_ir,
+        name + "_input_shape",
+        input.shape,
+        input,
+    )
+    padded_shape_tensor = input_shape_tensor
+
     start_list = [0] * rank
-    new_shape = list(input.shape)
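+    # pad is ordered (last_dim_before, last_dim_after, 2nd_last_before, ...), so
+    # pad pair i maps to input dimension rank - (i + 1).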
+    for i in range(len(pad) // 2):
+        dim_index = rank - (i + 1)
+        pad_before = pad[i * 2]
+        pad_after = pad[i * 2 + 1]
 
-    for i in range(0, len(pad) // 2):
-        start_list[-i - 1] = -pad[i * 2]
-        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
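+        # The current extent of this dim comes from slicing one element out of
+        # the shape tensor; adding the total padding gives the padded extent.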
+        pad_sum = get_trt_tensor(
+            ctx, pad_before + pad_after, f"{name}_pad_sum_{i}", dtype=np.int32
+        )
+        dim_shape = ctx.net.add_slice(
+            input_shape_tensor,
+            start=(dim_index,),
+            shape=(1,),
+            stride=(1,),
+        ).get_output(0)
+
+        new_dim_shape = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_shape_dim_{i}", dim_shape, pad_sum
+        )
+        start_list[dim_index] = -pad_before
+
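+        # Rebuild the running shape tensor, swapping in the padded extent for
+        # this dimension and keeping all other entries as-is.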
+        slices = []
+        for j in range(rank):
+            if j == dim_index:
+                slices.append(new_dim_shape)
+            else:
+                slices.append(
+                    ctx.net.add_slice(
+                        padded_shape_tensor,
+                        start=(j,),
+                        shape=(1,),
+                        stride=(1,),
+                    ).get_output(0)
+                )
+        padded_shape_tensor = impl.cat.cat(
+            ctx, target, source_ir, f"{name}_cat_{i}", slices, 0
+        )
+
+    start_indices_tensor = get_trt_tensor(
+        ctx,
+        np.array(start_list, dtype=np.int32),
+        f"{name}_start_indices_tensor",
+        dtype=np.int32,
+    )
+
+    return start_indices_tensor, padded_shape_tensor
+
+
+def constant_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    pad: Sequence[int],
+    value: Union[int, float] = 0,
+) -> TRTTensor:
+
+    rank = len(input.shape)
+
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, pad
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
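+    # The empty Dims are placeholders; the real start/shape/stride are wired in
+    # below as dynamic tensor inputs 1-3 of the slice layer.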
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
     )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
+
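+    # In FILL mode, out-of-bounds coordinates read the fill value supplied as
+    # input 4 of the slice layer.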
     value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
     layer.set_input(4, value_const)
     layer.mode = trt.SampleMode.FILL
@@ -67,30 +139,26 @@ def reflection_padNd(
     input: TRTTensor,
     padding: Sequence[int],
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
 
-    if len(padding) // 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
-        )
-
-    start_list = [0] * rank
-    new_shape = list(input.shape)
-
-    for i in range(0, len(padding) // 2):
-        start_list[-i - 1] = -padding[i * 2]
-        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, padding
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
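+    # Same dynamic-slice pattern as constant_padNd, but REFLECT mode mirrors
+    # interior values across the borders instead of filling a constant.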
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
     )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.REFLECT
 
     set_layer_name(layer, target, name, source_ir)
@@ -105,30 +173,26 @@ def replication_padNd(
     input: TRTTensor,
     padding: Sequence[int],
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
 
-    if len(padding) // 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
-        )
-
-    start_list = [0] * rank
-    new_shape = list(input.shape)
-
-    for i in range(0, len(padding) // 2):
-        start_list[-i - 1] = -padding[i * 2]
-        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, padding
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
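+    # CLAMP mode clamps out-of-bounds coordinates to the borders, which
+    # replicates the edge values.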
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
     )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.CLAMP
 
     set_layer_name(layer, target, name, source_ir)
@@ -141,32 +205,28 @@ def circular_padNd(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    pad: Sequence[int],
+    padding: Sequence[int],
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
 
-    if len(pad) // 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
-        )
-
-    start_list = [0] * rank
-    new_shape = list(input.shape)
-
-    for i in range(0, len(pad) // 2):
-        start_list[-i - 1] = -pad[i * 2]
-        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, padding
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
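+    # WRAP mode wraps out-of-bounds coordinates around the tensor, producing
+    # circular padding.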
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
     )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.WRAP
 
     set_layer_name(layer, target, name, source_ir)