@@ -15,6 +15,18 @@ namespace converters {
15
15
namespace impl {
16
16
namespace {
17
17
18
+ nvinfer1::ITensor* concat (int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
19
+ if (max_rank - old_rank > 0 ) {
20
+ torch::Tensor thOne = torch::tensor (std::vector<int32_t >(max_rank - old_rank, 1 ), torch::kInt32 );
21
+ auto one_tensor = tensor_to_const (ctx, thOne);
22
+ auto in_shape_tensor = ctx->net ->addShape (*tensor)->getOutput (0 );
23
+ nvinfer1::ITensor* const args[2 ] = {one_tensor, in_shape_tensor};
24
+ return ctx->net ->addConcatenation (args, 2 )->getOutput (0 );
25
+ } else { // max_rank - old_rank == 0
26
+ return ctx->net ->addShape (*tensor)->getOutput (0 );
27
+ }
28
+ }
29
+
18
30
bool add_expand (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
19
31
auto input_dims = in->getDimensions ();
20
32
TRTORCH_CHECK (
@@ -27,12 +39,26 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
27
39
int64_t dim = input_dims.nbDims - 1 - offset;
28
40
int64_t size = (dim >= 0 ) ? input_dims.d [dim] : 1 ;
29
41
int64_t targetSize = expandedDims.d [i];
30
- if (size != targetSize) {
31
- if (size != 1 ) {
42
+ // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
43
+ if (targetSize != -1 ) {
44
+ if (size != targetSize) {
45
+ if (size != 1 ) {
46
+ TRTORCH_THROW_ERROR (
47
+ " The expanded size of tensor (" << targetSize << " )"
48
+ << " must match the existing size (" << size << " )"
49
+ << " at dimension " << i);
50
+ }
51
+ }
52
+ } else {
53
+ // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
54
+ // not [-1, 3, 4].
55
+ if (dim < 0 ) {
32
56
TRTORCH_THROW_ERROR (
33
- " The expanded size of tensor (" << targetSize << " )"
34
- << " must match the existing size (" << size << " )"
35
- << " at dimension " << i);
57
+ " The expanded size of the tensor (" << targetSize << " ) isn't allowed in a leading, non-existing dimension "
58
+ << i);
59
+ } else {
60
+ // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
61
+ expandedDims.d [i] = input_dims.d [dim];
36
62
}
37
63
}
38
64
}
@@ -76,77 +102,192 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
76
102
return true ;
77
103
}
78
104
105
+ bool add_expand_dynamic (
106
+ ConversionCtx* ctx,
107
+ const torch::jit::Node* n,
108
+ nvinfer1::ITensor* in,
109
+ nvinfer1::ITensor* expandedDimsTensor,
110
+ nvinfer1::Dims expandedDims,
111
+ bool is_expand_layer) {
112
+ auto input_dims = in->getDimensions ();
113
+ auto input_rank = in->getDimensions ().nbDims ;
114
+ auto output_rank = expandedDims.nbDims ;
115
+ TRTORCH_CHECK (
116
+ input_rank <= output_rank,
117
+ " Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions" );
118
+
119
+ /* TODO: When the inputs are dynamic, some dimensions of the inputs are indeterminate before setBindingDimensions. For
120
+ these indeterminate dimensions, we don't validate the expansion. Eg: For an input of [3, -1], we omit the
121
+ validation of the second dimension. Need to explore a better way to validate the expansion.
122
+ */
123
+ // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
124
+ for (int64_t i = expandedDims.nbDims - 1 ; i >= 0 ; --i) {
125
+ int64_t offset = expandedDims.nbDims - 1 - i;
126
+ int64_t dim = input_dims.nbDims - 1 - offset;
127
+ int64_t size = (dim >= 0 ) ? input_dims.d [dim] : 1 ;
128
+ int64_t targetSize = expandedDims.d [i];
129
+ // Passing -1 as the size for a dimension means not changing the size of that dimension in expand layer.
130
+ if (targetSize != -1 ) {
131
+ if (size != targetSize) {
132
+ // if size == -1, we can't validate the expansion before setBindingDimensions.
133
+ if (!(size == -1 || size == 1 )) {
134
+ TRTORCH_THROW_ERROR (
135
+ " The expanded size of tensor (" << targetSize << " )"
136
+ << " must match the existing size (" << size << " )"
137
+ << " at dimension " << i);
138
+ }
139
+ }
140
+ } else {
141
+ // In dynamic expand layer, for the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be
142
+ // expanded to [3, -1, 4] but not [-1, 3, 4].
143
+ if (is_expand_layer && dim < 0 ) {
144
+ TRTORCH_THROW_ERROR (
145
+ " The expanded size of the tensor (" << targetSize << " ) isn't allowed in a leading, non-existing dimension "
146
+ << i);
147
+ }
148
+ }
149
+ }
150
+
151
+ size_t max_rank = std::max (input_rank, output_rank);
152
+
153
+ // Dimensions are right alignment. Eg: an input of [3, 1] and max_rank = 4, the result of concat is [1, 1, 3, 1]
154
+ auto new_input_shape_tensor = concat (max_rank, input_rank, ctx, in);
155
+ auto new_output_shape_tensor = expandedDimsTensor;
156
+
157
+ // Add a reshape layer to expand dims
158
+ auto shuffle = ctx->net ->addShuffle (*in);
159
+ shuffle->setInput (1 , *new_input_shape_tensor);
160
+
161
+ // Start the slicing from beginning of tensor since this is an expand layer
162
+ std::vector<int64_t > start_vec (max_rank, 0 );
163
+ nvinfer1::Dims starts_dim = util::toDims (c10::IntArrayRef (start_vec));
164
+ at::Tensor thStart = torch::tensor (util::toVec (starts_dim), torch::kInt32 );
165
+ auto starts = tensor_to_const (ctx, thStart);
166
+
167
+ // compute sizes = max(x,y).
168
+ auto sizes =
169
+ ctx->net ->addElementWise (*new_input_shape_tensor, *new_output_shape_tensor, nvinfer1::ElementWiseOperation::kMAX )
170
+ ->getOutput (0 );
171
+ nvinfer1::Dims sizes_dim{-1 , {}};
172
+ sizes_dim.nbDims = max_rank;
173
+
174
+ // Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
175
+ // min(1, sub(input_shape, 1))
176
+ torch::Tensor thOne = torch::tensor ({1 }, torch::kInt32 );
177
+ auto one_tensor = tensor_to_const (ctx, thOne);
178
+ auto x_sub_one = ctx->net ->addElementWise (*new_input_shape_tensor, *one_tensor, nvinfer1::ElementWiseOperation::kSUB )
179
+ ->getOutput (0 );
180
+ auto strides = ctx->net ->addElementWise (*one_tensor, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN )->getOutput (0 );
181
+ nvinfer1::Dims strides_dim{-1 , {}};
182
+ strides_dim.nbDims = max_rank;
183
+
184
+ // Slice layer does the expansion in TRT. Desired output size is specified by sizes input at index 2.
185
+ auto slice = ctx->net ->addSlice (*shuffle->getOutput (0 ), starts_dim, sizes_dim, strides_dim);
186
+ slice->setInput (1 , *starts);
187
+ slice->setInput (2 , *sizes);
188
+ slice->setInput (3 , *strides);
189
+
190
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], slice->getOutput (0 ));
191
+
192
+ LOG_DEBUG (" Expand layer output tensor shape: " << out_tensor->getDimensions ());
193
+
194
+ return true ;
195
+ }
196
+
79
197
// Converter registrations for aten::expand, aten::expand_as and aten::repeat.
// Each converter dispatches to the dynamic-shape path (add_expand_dynamic /
// shape-tensor reshape) when the network has dynamic inputs, and to the
// static-shape path otherwise.
auto expand_registrations TRTORCH_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern(
            {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in = args[0].ITensor();
               auto input_dims = in->getDimensions();
               auto expanded_size = args[1].unwrapToIntList();
               auto expandedDims = util::toDims(expanded_size);
               LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
               if (ctx->input_is_dynamic) {
                 // Dynamic path: feed the requested sizes in as a constant shape tensor.
                 at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
                 auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
                 return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
               } else {
                 return add_expand(ctx, n, in, expandedDims);
               }
             }})
        .pattern(
            {"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in = args[0].ITensor();
               auto input_dims = in->getDimensions();
               auto targetTensor = args[1].ITensor();
               auto targetDims = targetTensor->getDimensions();
               LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
               if (ctx->input_is_dynamic) {
                 // Dynamic path: read the target shape off the reference tensor at runtime.
                 return add_expand_dynamic(
                     ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
               } else {
                 return add_expand(ctx, n, in, targetDims);
               }
             }})
        .pattern(
            {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in = args[0].ITensor();
               auto input_dims = in->getDimensions();
               auto repeats = args[1].unwrapToIntList().vec();
               int repeats_rank = repeats.size();
               TRTORCH_CHECK(
                   repeats_rank >= input_dims.nbDims,
                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
               auto num_expand_dims = repeats_rank - input_dims.nbDims;

               if (ctx->input_is_dynamic) {
                 // Left-pad the runtime shape with 1s up to repeats_rank, then
                 // reshape the input to that padded shape.
                 int input_rank = input_dims.nbDims;
                 int output_rank = repeats_rank;
                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);

                 // Add a reshape layer to expand dims
                 auto shuffle = ctx->net->addShuffle(*in);
                 shuffle->setInput(1, *new_input_shape_tensor);
                 in = shuffle->getOutput(0);
               } else {
                 if (num_expand_dims > 0) {
                   // Static path: build the padded dims directly.
                   nvinfer1::Dims reshape_dims;
                   reshape_dims.nbDims = repeats.size();
                   for (int i = 0; i < num_expand_dims; i++) {
                     reshape_dims.d[i] = 1;
                   }
                   for (int i = 0; i < input_dims.nbDims; i++) {
                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
                   }
                   // Add a reshape layer to expand dims
                   auto reshape_layer = ctx->net->addShuffle(*in);
                   reshape_layer->setReshapeDimensions(reshape_dims);
                   in = reshape_layer->getOutput(0);
                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
                 }
                 LOG_DEBUG("Repeats: " << repeats);
               }

               // Concat across all repeat axes.
               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
               for (int i = repeats.size() - 1; i >= 0; --i) {
                 std::vector<nvinfer1::ITensor*> tensors_vec;
                 for (int j = 0; j < repeats[i]; j++) {
                   tensors_vec.push_back(in);
                 }
                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
                 concat_layer->setAxis(i);
                 in = concat_layer->getOutput(0);
               }

               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
               return true;
             }});
146
287
147
288
} // namespace
148
289
} // namespace impl
149
290
} // namespace converters
150
291
} // namespace conversion
151
292
} // namespace core
152
- } // namespace trtorch
293
+ } // namespace trtorch
0 commit comments