@@ -253,6 +253,42 @@ auto select_registrations TORCHTRT_UNUSED =
253
253
return true ;
254
254
}
255
255
}})
256
+ .pattern({" aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)" ,
257
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
258
+ auto in = args[0 ].ITensorOrFreeze (ctx);
259
+ auto ts = args[1 ].IValue ()->toListRef ();
260
+
261
+ std::vector<nvinfer1::ITensor*> tensors;
262
+ for (auto t : ts) {
263
+ if (t.isTensor ()) {
264
+ auto torch_tensor = t.toTensor ();
265
+ tensors.push_back (tensor_to_const (ctx, torch_tensor));
266
+ } else {
267
+ auto cont = t.toCustomClass <TensorContainer>();
268
+ tensors.push_back (cont->tensor ());
269
+ }
270
+ }
271
+
272
+ TORCHTRT_CHECK (
273
+ tensors.size () == 1 ,
274
+ " This version of Torch-TensorRT only supports one index in aten::index.Tensor" );
275
+ auto indicesTensor = tensors[0 ];
276
+ // Set datatype for indices tensor to INT32
277
+ auto identity = ctx->net ->addIdentity (*indicesTensor);
278
+ identity->setOutputType (0 , nvinfer1::DataType::kINT32 );
279
+ indicesTensor = identity->getOutput (0 );
280
+
281
+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
282
+ // from
283
+ auto gather_layer = ctx->net ->addGather (*in, *indicesTensor, 0 );
284
+ TORCHTRT_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
285
+ auto gather_out = gather_layer->getOutput (0 );
286
+
287
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gather_out);
288
+
289
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
290
+ return true ;
291
+ }})
256
292
.pattern(
257
293
{" aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)" ,
258
294
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
0 commit comments