3
3
#include " core/conversion/tensorcontainer/TensorContainer.h"
4
4
#include " core/util/prelude.h"
5
5
#include " core/util/trt_util.h"
6
+ #include " plugins/checkshape_plugin.h"
6
7
#include " torch/torch.h"
7
8
8
9
#include < ATen/ATen.h>
@@ -15,24 +16,93 @@ namespace converters {
15
16
namespace impl {
16
17
namespace {
17
18
19
+ nvinfer1::ILayer* create_plugin (
20
+ ConversionCtx* ctx,
21
+ const torch::jit::Node* n,
22
+ nvinfer1::ITensor* inShape,
23
+ nvinfer1::ITensor* expandShape,
24
+ int32_t in_rank,
25
+ int32_t expand_rank,
26
+ const char * name) {
27
+ auto creator = new plugins::CheckShapePluginCreator ();
28
+ std::vector<nvinfer1::PluginField> fields;
29
+ nvinfer1::PluginField input_rank (" input_rank" , &in_rank, nvinfer1::PluginFieldType::kINT32 , 1 );
30
+ nvinfer1::PluginField output_rank (" expand_rank" , &expand_rank, nvinfer1::PluginFieldType::kINT32 , 1 );
31
+ fields.push_back (input_rank);
32
+ fields.push_back (output_rank);
33
+ nvinfer1::PluginFieldCollection collection;
34
+ collection.nbFields = fields.size ();
35
+ collection.fields = fields.data ();
36
+ auto plugin = creator->createPlugin (name, &collection);
37
+
38
+ nvinfer1::ITensor* inputs[] = {inShape, expandShape};
39
+ auto expandShape_layer = ctx->net ->addPluginV2 (inputs, 2 , *plugin);
40
+ TRTORCH_CHECK (expandShape_layer, " Unable to create interpolation plugin from node" << *n);
41
+
42
+ expandShape_layer->setName (" CheckShapePlugin" );
43
+ return expandShape_layer;
44
+ }
45
+
46
+ void addSliceInput (nvinfer1::Dims& dims, int idx, ConversionCtx* ctx, nvinfer1::ISliceLayer* slice) {
47
+ int32_t rank = static_cast <int32_t >(dims.nbDims );
48
+ int32_t * tmp = new int32_t [rank];
49
+ for (int i=0 ;i<rank;i++)
50
+ tmp[i] = dims.d [i];
51
+ const nvinfer1::Dims d{1 , {rank}};
52
+ const nvinfer1::Weights w{nvinfer1::DataType::kINT32 , tmp, rank};
53
+ auto t = ctx->net ->addConstant (d, w)->getOutput (0 );
54
+ slice->setInput (idx, *t);
55
+ }
56
+
57
+ nvinfer1::ITensor* vec2Tensor (int32_t *dim, int rank, ConversionCtx* ctx){
58
+ const nvinfer1::Dims d{1 , {static_cast <int32_t >(rank)}};
59
+ const nvinfer1::Weights w{nvinfer1::DataType::kINT32 , dim, rank};
60
+ return ctx->net ->addConstant (d, w)->getOutput (0 );
61
+ }
62
+
63
+ nvinfer1::ITensor * concat (int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor*tensor){
64
+ if (max_rank - old_rank > 0 ){
65
+ int32_t * tmp = new int32_t [max_rank - old_rank];
66
+ for (int i=0 ;i<(max_rank - old_rank);i++)
67
+ tmp[i] = 1 ;
68
+ auto max_rank_tensor = vec2Tensor (tmp, max_rank - old_rank, ctx);
69
+ auto in_shape_tensor = ctx->net ->addShape (*tensor)->getOutput (0 );
70
+ nvinfer1::ITensor* const args[2 ] = {max_rank_tensor, in_shape_tensor};
71
+ return ctx->net ->addConcatenation (args, 2 )->getOutput (0 );
72
+ }else { // max_rank - old_rank == 0
73
+ return ctx->net ->addShape (*tensor)->getOutput (0 );
74
+ }
75
+ }
76
+
18
77
bool add_expand (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
19
78
auto input_dims = in->getDimensions ();
20
79
TRTORCH_CHECK (
21
80
input_dims.nbDims <= expandedDims.nbDims ,
22
81
" Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions" );
23
82
24
83
// Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
25
- for (int64_t i = expandedDims.nbDims - 1 ; i >= 0 ; --i) {
84
+ for (int i = expandedDims.nbDims - 1 ; i >= 0 ; --i) {
26
85
int64_t offset = expandedDims.nbDims - 1 - i;
27
86
int64_t dim = input_dims.nbDims - 1 - offset;
28
87
int64_t size = (dim >= 0 ) ? input_dims.d [dim] : 1 ;
29
88
int64_t targetSize = expandedDims.d [i];
30
- if (size != targetSize) {
31
- if (size != 1 ) {
32
- TRTORCH_THROW_ERROR (
33
- " The expanded size of tensor (" << targetSize << " )"
34
- << " must match the existing size (" << size << " )"
35
- << " at dimension " << i);
89
+ if (targetSize != -1 ){
90
+ if (size != targetSize) {
91
+ if (size != 1 ) {
92
+ TRTORCH_THROW_ERROR (
93
+ " The expanded size of tensor (" << targetSize << " )"
94
+ << " must match the existing size (" << size << " )"
95
+ << " at dimension " << i);
96
+ }
97
+ }
98
+ }else {
99
+ if (dim < 0 ){
100
+ TRTORCH_THROW_ERROR (" The expanded size of the tensor (" << \
101
+ targetSize << " ) isn't allowed in a leading, non-existing dimension " << \
102
+ i);
103
+ }else {
104
+ // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
105
+ expandedDims.d [i] = input_dims.d [dim];
36
106
}
37
107
}
38
108
}
@@ -41,10 +111,10 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
41
111
if (num_expand_dims > 0 ) {
42
112
nvinfer1::Dims reshape_dims;
43
113
reshape_dims.nbDims = expandedDims.nbDims ;
44
- for (int64_t i = 0 ; i < num_expand_dims; i++) {
114
+ for (int i = 0 ; i < num_expand_dims; i++) {
45
115
reshape_dims.d [i] = 1 ;
46
116
}
47
- for (int64_t i = 0 ; i < input_dims.nbDims ; i++) {
117
+ for (int i = 0 ; i < input_dims.nbDims ; i++) {
48
118
reshape_dims.d [num_expand_dims + i] = input_dims.d [i];
49
119
}
50
120
// Add a reshape layer to expand dims
@@ -60,7 +130,7 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
60
130
61
131
// Set the stride of non singleton dimension to 1
62
132
std::vector<int64_t > strides_vec (expandedDims.nbDims , 0 );
63
- for (int64_t i = 0 ; i < expandedDims.nbDims ; i++) {
133
+ for (int i = 0 ; i < expandedDims.nbDims ; i++) {
64
134
strides_vec[i] = (in->getDimensions ().d [i] != 1 );
65
135
}
66
136
@@ -76,6 +146,61 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
76
146
return true ;
77
147
}
78
148
149
+ bool add_expand_dynamic (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::ITensor* expandedDimsTensor){
150
+ auto input_shape_tensor = ctx->net ->addShape (*in)->getOutput (0 );
151
+ auto input_rank = in->getDimensions ().nbDims ;
152
+ auto output_rank = expandedDimsTensor->getDimensions ().d [0 ];
153
+ TRTORCH_CHECK (
154
+ input_rank <= output_rank,
155
+ " Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions" );
156
+
157
+ // add a plugin to check expandedDimsTensor whether match input_shape_tensor
158
+ auto expandShape_layer = create_plugin (ctx, n, input_shape_tensor, expandedDimsTensor, input_rank, output_rank, " expandShape" );
159
+ auto _tensor = expandShape_layer->getOutput (0 );
160
+
161
+ size_t max_rank = std::max (input_rank, output_rank);
162
+
163
+ // Dimensions are right alignment
164
+ auto new_input_shape_tensor = concat (max_rank, input_rank, ctx, in);
165
+ // LOG_DEBUG("Expand layer output tensor shape: " << new_output_shape_tensor->getDimensions());
166
+ auto new_output_shape_tensor = expandedDimsTensor;
167
+
168
+ // Add a reshape layer to expand dims
169
+ auto shuffle = ctx->net ->addShuffle (*in);
170
+ shuffle->setInput (1 , *new_input_shape_tensor);
171
+
172
+ // Start the slicing from beginning of tensor since this is an expand layer
173
+ std::vector<int64_t > start_vec (max_rank, 0 );
174
+ nvinfer1::Dims starts_dim = util::toDims (c10::IntArrayRef (start_vec));
175
+
176
+ // compute sizes = max(x,y).
177
+ auto sizes = ctx->net ->addElementWise (*new_input_shape_tensor, *new_output_shape_tensor, nvinfer1::ElementWiseOperation::kMAX )->getOutput (0 );
178
+ nvinfer1::Dims sizes_dim{-1 , {}};
179
+ sizes_dim.nbDims = max_rank;
180
+
181
+ // Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
182
+ // min(1, sub(input_shape, 1))
183
+ int32_t * one_vector_tmp = new int32_t [1 ];
184
+ one_vector_tmp[0 ] = 1 ;
185
+ auto one_vector = vec2Tensor (one_vector_tmp, 1 , ctx);
186
+ auto x_sub_one = ctx->net ->addElementWise (*new_input_shape_tensor, *one_vector, nvinfer1::ElementWiseOperation::kSUB )->getOutput (0 );
187
+ auto strides = ctx->net ->addElementWise (*one_vector, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN )->getOutput (0 );
188
+ nvinfer1::Dims strides_dim{-1 , {}};
189
+ strides_dim.nbDims = max_rank;
190
+
191
+ // Slice layer does the expansion in TRT. Desired output size is specified by expandedDimsTensor
192
+ auto slice = ctx->net ->addSlice (*shuffle->getOutput (0 ), starts_dim, sizes_dim, strides_dim);
193
+ addSliceInput (starts_dim, 1 , ctx, slice);
194
+ slice->setInput (2 , *sizes);
195
+ slice->setInput (3 , *strides);
196
+
197
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], slice->getOutput (0 ));
198
+
199
+ LOG_DEBUG (" Expand layer output tensor shape: " << out_tensor->getDimensions ());
200
+
201
+ return true ;
202
+ }
203
+
79
204
auto expand_registrations TRTORCH_UNUSED =
80
205
RegisterNodeConversionPatterns ()
81
206
.pattern({" aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))" ,
@@ -85,51 +210,75 @@ auto expand_registrations TRTORCH_UNUSED =
85
210
auto expanded_size = args[1 ].unwrapToIntList ();
86
211
auto expandedDims = util::toDims (expanded_size);
87
212
LOG_DEBUG (" (expand layer) Expand input from " << input_dims << " to " << expandedDims);
88
- return add_expand (ctx, n, in, expandedDims);
213
+ if (ctx->input_is_dynamic ){
214
+ int expanded_size_rank = static_cast <int >(expanded_size.size ());
215
+ int32_t * tmp = new int32_t [expanded_size_rank];
216
+ for (int i=0 ;i<expanded_size_rank;i++)
217
+ tmp[i] = expanded_size[i];
218
+ auto expandedDimsTensor = vec2Tensor (tmp, expanded_size_rank, ctx);
219
+ return add_expand_dynamic (ctx, n, in, expandedDimsTensor);
220
+ }else {
221
+ return add_expand (ctx, n, in, expandedDims);
222
+ }
89
223
}})
90
224
.pattern({" aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))" ,
91
225
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
92
- // TODO: Currently expand supports static shapes. Need to explore if the same code can be extended
93
- // to dynamic expansion.
94
226
auto in = args[0 ].ITensor ();
95
227
auto input_dims = in->getDimensions ();
96
228
auto targetTensor = args[1 ].ITensor ();
97
229
auto targetDims = targetTensor->getDimensions ();
98
230
LOG_DEBUG (" (expand_as layer) Expand input from " << input_dims << " to " << targetDims);
99
- return add_expand (ctx, n, in, targetDims);
231
+ if (ctx->input_is_dynamic ){
232
+ return add_expand_dynamic (ctx, n, in, ctx->net ->addShape (*targetTensor)->getOutput (0 ));
233
+ }else {
234
+ return add_expand (ctx, n, in, targetDims);
235
+ }
236
+
100
237
}})
101
238
.pattern({" aten::repeat(Tensor self, int[] repeats) -> (Tensor)" ,
102
239
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
103
240
auto in = args[0 ].ITensor ();
104
241
auto input_dims = in->getDimensions ();
105
242
auto repeats = args[1 ].unwrapToIntList ().vec ();
243
+ int repeats_rank = repeats.size ();
106
244
TRTORCH_CHECK (
107
- static_cast < int64_t >(repeats. size ()) >= input_dims.nbDims ,
245
+ repeats_rank >= input_dims.nbDims ,
108
246
" Number of repeat dimensions cannot be smaller than number of input dimensions" );
109
- auto num_expand_dims = repeats.size () - input_dims.nbDims ;
110
- if (num_expand_dims > 0 ) {
111
- nvinfer1::Dims reshape_dims;
112
- reshape_dims.nbDims = repeats.size ();
113
- for (size_t i = 0 ; i < num_expand_dims; i++) {
114
- reshape_dims.d [i] = 1 ;
115
- }
116
- for (int64_t i = 0 ; i < input_dims.nbDims ; i++) {
117
- reshape_dims.d [num_expand_dims + i] = input_dims.d [i];
118
- }
247
+ auto num_expand_dims = repeats_rank - input_dims.nbDims ;
248
+
249
+ if (ctx->input_is_dynamic ){
250
+ int input_rank = input_dims.nbDims ;
251
+ int output_rank= repeats_rank;
252
+ auto new_input_shape_tensor = concat (output_rank, input_rank, ctx, in);
253
+
119
254
// Add a reshape layer to expand dims
120
- auto reshape_layer = ctx->net ->addShuffle (*in);
121
- reshape_layer->setReshapeDimensions (reshape_dims);
122
- in = reshape_layer->getOutput (0 );
123
- LOG_DEBUG (" Input reshaped to : " << in->getDimensions () << " from " << input_dims);
255
+ auto shuffle = ctx->net ->addShuffle (*in);
256
+ shuffle->setInput (1 , *new_input_shape_tensor);
257
+ in = shuffle->getOutput (0 );
258
+ }else {
259
+ if (num_expand_dims > 0 ) {
260
+ nvinfer1::Dims reshape_dims;
261
+ reshape_dims.nbDims = repeats.size ();
262
+ for (int i = 0 ; i < num_expand_dims; i++) {
263
+ reshape_dims.d [i] = 1 ;
264
+ }
265
+ for (int i = 0 ; i < input_dims.nbDims ; i++) {
266
+ reshape_dims.d [num_expand_dims + i] = input_dims.d [i];
267
+ }
268
+ // Add a reshape layer to expand dims
269
+ auto reshape_layer = ctx->net ->addShuffle (*in);
270
+ reshape_layer->setReshapeDimensions (reshape_dims);
271
+ in = reshape_layer->getOutput (0 );
272
+ LOG_DEBUG (" Input reshaped to : " << in->getDimensions () << " from " << input_dims);
273
+ }
274
+ LOG_DEBUG (" Repeats: " << repeats);
124
275
}
125
276
126
- LOG_DEBUG (" Repeats: " << repeats);
127
-
128
277
// Concat across all repeat axes.
129
278
// TODO: Implementation might not be performant. Explore other strategies to improve performance.
130
- for (int64_t i = repeats.size () - 1 ; i >= 0 ; --i) {
279
+ for (int i = repeats.size () - 1 ; i >= 0 ; --i) {
131
280
std::vector<nvinfer1::ITensor*> tensors_vec;
132
- for (int64_t j = 0 ; j < repeats[i]; j++) {
281
+ for (int j = 0 ; j < repeats[i]; j++) {
133
282
tensors_vec.push_back (in);
134
283
}
135
284
auto concat_layer = ctx->net ->addConcatenation (tensors_vec.data (), tensors_vec.size ());
@@ -139,8 +288,7 @@ auto expand_registrations TRTORCH_UNUSED =
139
288
140
289
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], in);
141
290
142
- LOG_DEBUG (" Repeat layer output tensor shape: " << out->getDimensions ());
143
-
291
+ LOG_DEBUG (" Repeat layer output tensor shape: " << in->getDimensions ());
144
292
return true ;
145
293
}});
146
294
@@ -149,4 +297,4 @@ auto expand_registrations TRTORCH_UNUSED =
149
297
} // namespace converters
150
298
} // namespace conversion
151
299
} // namespace core
152
- } // namespace trtorch
300
+ } // namespace trtorch
0 commit comments