@@ -271,37 +271,229 @@ auto select_registrations TORCHTRT_UNUSED =
                auto ts = args[1].IValue()->toListRef();
 
                std::vector<nvinfer1::ITensor*> tensors;
-               for (auto t : ts) {
+               std::vector<int32_t> adv_idx_indices;
+               for (auto i = 0; i < ts.size(); i++) {
+                 auto t = ts[i];
                  if (t.isTensor()) {
-                   auto torch_tensor = t.toTensor();
+                   auto torch_tensor = t.toTensor().to(torch::kInt32);
                    tensors.push_back(tensor_to_const(ctx, torch_tensor));
+                   adv_idx_indices.push_back(i);
                  } else {
-                   auto cont = t.toCustomClass<TensorContainer>();
-                   tensors.push_back(cont->tensor());
+                   // IValue
+                   if (!t.isNone()) {
+                     adv_idx_indices.push_back(i);
+                     auto cont = t.toCustomClass<TensorContainer>();
+                     // Set datatype for indices tensor to INT32
+                     auto identity = ctx->net->addIdentity(*cont->tensor());
+                     identity->setOutputType(0, nvinfer1::DataType::kINT32);
+                     tensors.push_back(identity->getOutput(0));
+                   }
                  }
                }
 
-               // In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
-               // indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
-               // indexes the self tensor along dimension 0.
-               TORCHTRT_CHECK(
-                   tensors.size() == 1,
-                   "In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
-               auto indicesTensor = tensors[0];
-               // Set datatype for indices tensor to INT32
-               auto identity = ctx->net->addIdentity(*indicesTensor);
-               identity->setOutputType(0, nvinfer1::DataType::kINT32);
-               indicesTensor = identity->getOutput(0);
+               if (tensors.size() == 0) {
+                 auto identity_out = ctx->net->addIdentity(*in)->getOutput(0);
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity_out);
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               } else if (tensors.size() == 1) {
+                 auto indicesTensor = tensors[0];
+                 // Set datatype for indices tensor to INT32
+                 auto identity = ctx->net->addIdentity(*indicesTensor);
+                 identity->setOutputType(0, nvinfer1::DataType::kINT32);
+                 indicesTensor = identity->getOutput(0);
+
+                 // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
+                 // from
+                 auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
+                 TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+                 auto gather_out = gather_layer->getOutput(0);
+
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               } else {
+                 auto inDims = in->getDimensions();
+                 int rank = inDims.nbDims;
+                 LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results.");
+                 int adv_idx_count = adv_idx_indices.size();
+                 auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0);
+
+                 std::vector<nvinfer1::ITensor*> dim_tensor_list;
+                 for (int i = 0; i < rank; i++) {
+                   auto dim_tensor =
+                       ctx->net
+                           ->addGather(*in_shape_itensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                           ->getOutput(0);
+                   dim_tensor_list.push_back(dim_tensor);
+                 }
 
-               // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
-               // from
-               auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
-               TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-               auto gather_out = gather_layer->getOutput(0);
+                 // t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+                 // where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+                 // for ":".
+                 auto in_transpose_layer = ctx->net->addShuffle(*in);
+                 TORCHTRT_CHECK(in_transpose_layer, "Unable to create shuffle layer from node: " << *n);
+                 nvinfer1::Permutation permute;
+                 std::vector<int32_t> new_order;
+                 for (int i = 0; i < adv_idx_count; i++) {
+                   new_order.push_back(adv_idx_indices[i]);
+                 }
+                 for (int i = 0; i < rank; i++) {
+                   if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                     new_order.push_back(i);
+                   }
+                 }
+                 std::copy(new_order.begin(), new_order.end(), permute.order);
+                 in_transpose_layer->setSecondTranspose(permute);
+                 auto shuffle_out = in_transpose_layer->getOutput(0);
+
+                 // t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n]
+                 nvinfer1::ITensor* flatten_tensor = NULL;
+                 {
+                   auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0);
+                   auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+                   for (int i = 0; i < adv_idx_count; i++) {
+                     auto dim_tensor =
+                         ctx->net
+                             ->addGather(
+                                 *shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                             ->getOutput(0);
+                     d0 = add_elementwise(
+                         ctx,
+                         nvinfer1::ElementWiseOperation::kPROD,
+                         d0,
+                         dim_tensor,
+                         std::string("compute_dim0_") + std::to_string(i))
+                         ->getOutput(0);
+                   }
 
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+                   auto d1 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+                   for (int i = adv_idx_count; i < rank; i++) {
+                     auto dim_tensor =
+                         ctx->net
+                             ->addGather(
+                                 *shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                             ->getOutput(0);
+                     d1 = add_elementwise(
+                         ctx,
+                         nvinfer1::ElementWiseOperation::kPROD,
+                         d1,
+                         dim_tensor,
+                         std::string("compute_dim1_") + std::to_string(i))
+                         ->getOutput(0);
+                   }
 
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                   std::vector<nvinfer1::ITensor*> concat_tensors;
+                   concat_tensors.push_back(d0);
+                   concat_tensors.push_back(d1);
+                   auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+
+                   auto shuffle = ctx->net->addShuffle(*shuffle_out);
+                   shuffle->setInput(1, *concat_layer->getOutput(0));
+                   flatten_tensor = shuffle->getOutput(0);
+                   LOG_DEBUG(flatten_tensor->getDimensions());
+                 }
+
+                 // tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+                 // j dimension of input x.
+                 nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
+                 nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
+                 for (int i = adv_idx_count - 2; i >= 0; i--) {
+                   nvinfer1::ITensor* adv_index = add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kPROD,
+                       tensors[i],
+                       multiplier,
+                       std::string("adv_index_") + std::to_string(i))
+                       ->getOutput(0);
+                   cum_adv_index = add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kSUM,
+                       cum_adv_index,
+                       adv_index,
+                       std::string("cum_adv_index_") + std::to_string(i))
+                       ->getOutput(0);
+                   multiplier = add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kPROD,
+                       multiplier,
+                       dim_tensor_list[adv_idx_indices[i]],
+                       std::string("multiplier_") + std::to_string(i))
+                       ->getOutput(0);
+                 }
+
+                 // perform gather
+                 auto gather_out = ctx->net->addGather(*flatten_tensor, *cum_adv_index, 0)->getOutput(0);
+
+                 nvinfer1::ITensor* reshape_output = NULL;
+                 {
+                   auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0);
+                   // check if all advanced indices are consecutive.
+                   if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) {
+                     // unfold regular index axes
+                     std::vector<nvinfer1::ITensor*> concat_tensors;
+                     concat_tensors.push_back(tensor_to_const(ctx, torch::tensor({-1}, torch::kInt32)));
+                     for (int i = 0; i < rank; i++) {
+                       if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                         nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                         concat_tensors.push_back(current_dim);
+                       }
+                     }
+                     auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+                     auto regular_index_shuffle_layer = ctx->net->addShuffle(*gather_out);
+                     regular_index_shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+                     auto unfold_tensor = regular_index_shuffle_layer->getOutput(0);
+
+                     // Transpose folded advanced indexed axis to its original location.
+                     auto transpose_advanced_shuffle_layer = ctx->net->addShuffle(*unfold_tensor);
+                     nvinfer1::Permutation permute;
+                     std::vector<int32_t> new_order;
+                     for (int i = 1; i < adv_idx_indices[0] + 1; i++) {
+                       new_order.push_back(i);
+                     }
+                     new_order.push_back(0);
+                     for (int i = adv_idx_indices[0] + 1; i < rank - adv_idx_count + 1; i++) {
+                       new_order.push_back(i);
+                     }
+                     std::copy(new_order.begin(), new_order.end(), permute.order);
+                     transpose_advanced_shuffle_layer->setSecondTranspose(permute);
+                     auto shuffle_out = transpose_advanced_shuffle_layer->getOutput(0);
+
+                     // unfold advanced index axes
+                     std::vector<nvinfer1::ITensor*> concat_final_tensors;
+                     for (int i = 0; i < adv_idx_indices[0]; i++) {
+                       nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                       concat_final_tensors.push_back(current_dim);
+                     }
+                     concat_final_tensors.push_back(cum_adv_index_shape_tensor);
+                     for (int i = adv_idx_indices[0]; i < rank; i++) {
+                       if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                         nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                         concat_final_tensors.push_back(current_dim);
+                       }
+                     }
+                     auto concat_final_shape_layer =
+                         ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+                     auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
+                     unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
+                     reshape_output = unfold_advanced_shuffle_layer->getOutput(0);
+                   } else {
+                     std::vector<nvinfer1::ITensor*> concat_tensors;
+                     concat_tensors.push_back(cum_adv_index_shape_tensor);
+                     for (int i = 0; i < rank; i++) {
+                       if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                         nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                         concat_tensors.push_back(current_dim);
+                       }
+                     }
+                     auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+                     auto shuffle_layer = ctx->net->addShuffle(*gather_out);
+                     shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+                     reshape_output = shuffle_layer->getOutput(0);
+                   }
+                 }
+
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], reshape_output);
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               }
                return true;
              }})
         .pattern(
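The multi-index branch added above hinges on the cumulative flattened index spelled out in the converter comment, tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j): the advanced axes are transposed to the front, collapsed into a single axis, and one gather with cum_adv_index reproduces the multi-axis lookup. Below is a minimal standalone sketch of that identity, assuming libtorch is available; the toy shapes, index values, and the reshape/index_select calls are illustrative only, while the converter itself builds the same arithmetic from TensorRT IShuffle, IElementWise, and IGather layers.

#include <torch/torch.h>

#include <iostream>

int main() {
  // Rank-3 input and two advanced indices over its first two dimensions.
  auto x = torch::arange(2 * 3 * 4, torch::kFloat).reshape({2, 3, 4});
  auto i0 = torch::tensor({0, 1}, torch::kLong);
  auto i1 = torch::tensor({2, 0}, torch::kLong);

  // Reference result: PyTorch advanced indexing, x[i0, i1] -> shape [2, 4].
  auto ref = x.index({i0, i1});

  // Converter-style computation: fold dims 0 and 1 into one axis, then gather
  // with cum_adv_index = i0 * x.size(1) + i1, i.e. \sum ind_i * \prod_{j>i} x_j.
  auto flat = x.reshape({2 * 3, 4});
  auto cum_adv_index = i0 * x.size(1) + i1;
  auto gathered = flat.index_select(0, cum_adv_index);

  std::cout << std::boolalpha << torch::equal(ref, gathered) << std::endl;  // prints: true
  return 0;
}

In the converter, the same product-and-sum is assembled with add_elementwise calls over the shape slices in dim_tensor_list, and the final IShuffle layers restore the non-advanced dimensions around the gathered axis.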