@@ -19,15 +19,23 @@ namespace conversion {
19
19
namespace evaluators {
20
20
namespace {
21
21
22
- nvinfer1::ITensor* index_layer (){
23
-
22
+ nvinfer1::ITensor* index_layer (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_tensor, int64_t index){
23
+ // index to access needs to be an at::Tensor
24
+ at::Tensor indices = torch::tensor ({index}).to (torch::kI32 );
25
+ auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, indices);
26
+
27
+ auto gather_layer = ctx->net ->addGather (*input_tensor, *indices_out, 0 );
28
+ TORCHTRT_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
29
+ auto indexed_tensor = gather_layer->getOutput (0 );
30
+ return indexed_tensor;
24
31
}
25
32
26
33
c10::IValue dynamic_size_layer (ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){
27
34
LOG_DEBUG (" Using dynamic version of aten::size evaluator" );
28
35
auto in = args.at (n->input (0 )).ITensorOrFreeze (ctx);
29
36
LOG_DEBUG (" Input dimensions: " << in->getDimensions ());
30
37
auto shape_layer = ctx->net ->addShape (*in);
38
+ TORCHTRT_CHECK (shape_layer, " Unable to create shape layer from node: " << *n);
31
39
auto shape_1d_tensor = shape_layer->getOutput (0 );
32
40
33
41
if (n->inputs ().size () != 1 ){
@@ -36,15 +44,9 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw
36
44
// Handle negative axis by refering to nbDims of input Tensor
37
45
dim = dim < 0 ? dim + maxDim : dim;
38
46
LOG_DEBUG (" Dimension to select: " << dim);
39
-
40
- // index to access needs to be an at::Tensor
41
- at::Tensor indices = torch::tensor ({dim}).to (torch::kI32 );
42
- auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, indices);
43
-
44
- auto gather_layer = ctx->net ->addGather (*shape_1d_tensor, *indices_out, 0 );
45
- shape_1d_tensor = gather_layer->getOutput (0 );
47
+ shape_1d_tensor = index_layer (ctx, n, shape_1d_tensor, dim);
46
48
}
47
-
49
+
48
50
LOG_DEBUG (" Output tensor shape: " << shape_1d_tensor->getDimensions ());
49
51
50
52
auto tensor_holder = TensorContainer ();
@@ -364,13 +366,13 @@ auto aten_registrations TORCHTRT_UNUSED =
364
366
TORCHTRT_CHECK (
365
367
normalized_idx >= 0 || normalized_idx < list_size, " List index out of range (aten::__getitem__)" );
366
368
return list.get (normalized_idx);
367
- } elif (list_input.isITensor ()){
368
- return dynamic_size_layer (ctx, n, args);
369
+ } else if (list_input.isITensor ()){
370
+ auto indexed_tensor = index_layer (ctx, n, list_input.ITensorOrFreeze (ctx), idx);
371
+ auto tensor_holder = TensorContainer ();
372
+ tensor_holder.hold_tensor (indexed_tensor);
373
+ auto indexed_ivalue = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
374
+ return indexed_ivalue;
369
375
}
370
-
371
-
372
-
373
-
374
376
},
375
377
EvalOptions ().validSchemas ({
376
378
" aten::__getitem__.t(t[](a) list, int idx) -> (t(*))" ,
0 commit comments