@@ -196,94 +196,91 @@ bool add_expand_dynamic(

auto expand_registrations TRTORCH_UNUSED =
    RegisterNodeConversionPatterns()
-        .pattern(
-            {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensor();
-               auto input_dims = in->getDimensions();
-               auto expanded_size = args[1].unwrapToIntList();
-               auto expandedDims = util::toDims(expanded_size);
-               LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
-               if (ctx->input_is_dynamic) {
-                 at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
-                 auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
-                 return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
-               } else {
-                 return add_expand(ctx, n, in, expandedDims);
-               }
-             }})
-        .pattern(
-            {"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensor();
-               auto input_dims = in->getDimensions();
-               auto targetTensor = args[1].ITensor();
-               auto targetDims = targetTensor->getDimensions();
-               LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
-               if (ctx->input_is_dynamic) {
-                 return add_expand_dynamic(
-                     ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
-               } else {
-                 return add_expand(ctx, n, in, targetDims);
-               }
-             }})
-        .pattern(
-            {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensor();
-               auto input_dims = in->getDimensions();
-               auto repeats = args[1].unwrapToIntList().vec();
-               int repeats_rank = repeats.size();
-               TRTORCH_CHECK(
-                   repeats_rank >= input_dims.nbDims,
-                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
-               auto num_expand_dims = repeats_rank - input_dims.nbDims;
+        .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto input_dims = in->getDimensions();
+                    auto expanded_size = args[1].unwrapToIntList();
+                    auto expandedDims = util::toDims(expanded_size);
+                    LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
+                    if (ctx->input_is_dynamic) {
+                      at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
+                      auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
+                      return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
+                    } else {
+                      return add_expand(ctx, n, in, expandedDims);
+                    }
+                  }})
+        .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto input_dims = in->getDimensions();
+                    auto targetTensor = args[1].ITensor();
+                    auto targetDims = targetTensor->getDimensions();
+                    LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
+                    if (ctx->input_is_dynamic) {
+                      return add_expand_dynamic(
+                          ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
+                    } else {
+                      return add_expand(ctx, n, in, targetDims);
+                    }
+                  }})
+        .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto input_dims = in->getDimensions();
+                    auto repeats = args[1].unwrapToIntList().vec();
+                    int repeats_rank = repeats.size();
+                    TRTORCH_CHECK(
+                        repeats_rank >= input_dims.nbDims,
+                        "Number of repeat dimensions cannot be smaller than number of input dimensions");
+                    auto num_expand_dims = repeats_rank - input_dims.nbDims;

-               if (ctx->input_is_dynamic) {
-                 int input_rank = input_dims.nbDims;
-                 int output_rank = repeats_rank;
-                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+                    if (ctx->input_is_dynamic) {
+                      int input_rank = input_dims.nbDims;
+                      int output_rank = repeats_rank;
+                      auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);

-                 // Add a reshape layer to expand dims
-                 auto shuffle = ctx->net->addShuffle(*in);
-                 shuffle->setInput(1, *new_input_shape_tensor);
-                 in = shuffle->getOutput(0);
-               } else {
-                 if (num_expand_dims > 0) {
-                   nvinfer1::Dims reshape_dims;
-                   reshape_dims.nbDims = repeats.size();
-                   for (int i = 0; i < num_expand_dims; i++) {
-                     reshape_dims.d[i] = 1;
-                   }
-                   for (int i = 0; i < input_dims.nbDims; i++) {
-                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                   }
-                   // Add a reshape layer to expand dims
-                   auto reshape_layer = ctx->net->addShuffle(*in);
-                   reshape_layer->setReshapeDimensions(reshape_dims);
-                   in = reshape_layer->getOutput(0);
-                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-                 }
-                 LOG_DEBUG("Repeats: " << repeats);
-               }
+                      // Add a reshape layer to expand dims
+                      auto shuffle = ctx->net->addShuffle(*in);
+                      shuffle->setInput(1, *new_input_shape_tensor);
+                      in = shuffle->getOutput(0);
+                    } else {
+                      if (num_expand_dims > 0) {
+                        nvinfer1::Dims reshape_dims;
+                        reshape_dims.nbDims = repeats.size();
+                        for (int i = 0; i < num_expand_dims; i++) {
+                          reshape_dims.d[i] = 1;
+                        }
+                        for (int i = 0; i < input_dims.nbDims; i++) {
+                          reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                        }
+                        // Add a reshape layer to expand dims
+                        auto reshape_layer = ctx->net->addShuffle(*in);
+                        reshape_layer->setReshapeDimensions(reshape_dims);
+                        in = reshape_layer->getOutput(0);
+                        LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                      }
+                      LOG_DEBUG("Repeats: " << repeats);
+                    }

-               // Concat across all repeat axes.
-               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-               for (int i = repeats.size() - 1; i >= 0; --i) {
-                 std::vector<nvinfer1::ITensor*> tensors_vec;
-                 for (int j = 0; j < repeats[i]; j++) {
-                   tensors_vec.push_back(in);
-                 }
-                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-                 concat_layer->setAxis(i);
-                 in = concat_layer->getOutput(0);
-               }
+                    // Concat across all repeat axes.
+                    // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+                    for (int i = repeats.size() - 1; i >= 0; --i) {
+                      std::vector<nvinfer1::ITensor*> tensors_vec;
+                      for (int j = 0; j < repeats[i]; j++) {
+                        tensors_vec.push_back(in);
+                      }
+                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                      concat_layer->setAxis(i);
+                      in = concat_layer->getOutput(0);
+                    }

-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

-               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-               return true;
-             }});
+                    LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+                    return true;
+                  }});

} // namespace
} // namespace impl
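Note on the aten::repeat path in the hunk above: the converter tiles the input in two steps, first reshaping it to the rank of `repeats` (left-padding the shape with 1s), then concatenating copies of the tensor along every axis, last axis first. The sketch below is not part of this commit; it shows the same tiling strategy on a plain row-major CPU buffer, with hypothetical helper names (concat_copies, repeat_cpu) chosen only for illustration.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Concatenate `copies` copies of a row-major buffer along `axis`.
// Everything before `axis` forms an "outer" block; everything from `axis`
// onward is contiguous in memory, so each outer block is simply written
// `copies` times back to back.
std::vector<float> concat_copies(
    const std::vector<float>& data,
    const Shape& shape,
    size_t axis,
    int64_t copies,
    Shape& out_shape) {
  int64_t outer = 1, inner = 1;
  for (size_t i = 0; i < axis; i++) {
    outer *= shape[i];
  }
  for (size_t i = axis; i < shape.size(); i++) {
    inner *= shape[i];
  }
  std::vector<float> out;
  out.reserve(data.size() * copies);
  for (int64_t o = 0; o < outer; o++) {
    for (int64_t c = 0; c < copies; c++) {
      out.insert(out.end(), data.begin() + o * inner, data.begin() + (o + 1) * inner);
    }
  }
  out_shape = shape;
  out_shape[axis] *= copies;
  return out;
}

// Same strategy as the converter, on the CPU: pad the shape with leading 1s
// (the reshape step), then concatenate along every axis from last to first
// (the concat loop).
std::vector<float> repeat_cpu(std::vector<float> data, Shape shape, const Shape& repeats) {
  assert(repeats.size() >= shape.size());
  Shape padded(repeats.size() - shape.size(), 1);
  padded.insert(padded.end(), shape.begin(), shape.end());
  shape = padded;
  for (int i = static_cast<int>(repeats.size()) - 1; i >= 0; --i) {
    Shape new_shape;
    data = concat_copies(data, shape, i, repeats[i], new_shape);
    shape = new_shape;
  }
  return data;
}

int main() {
  // A 2x3 input with repeats {2, 1, 2} -> output shape {2, 2, 6}, 24 elements.
  std::vector<float> in = {1, 2, 3, 4, 5, 6};
  auto out = repeat_cpu(in, {2, 3}, {2, 1, 2});
  std::cout << "output elements: " << out.size() << std::endl; // prints 24
  return 0;
}

The TensorRT version above realizes the same two steps with an IShuffleLayer for the reshape and one IConcatenationLayer per axis, which is why the TODO in the converter notes the implementation may not be the most performant choice.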