Skip to content

Commit 6701759

Browse files
committed
chore: clarify aten::index.Tensor only supports one index in this version
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent ede9d13 commit 6701759

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -253,42 +253,46 @@ auto select_registrations TORCHTRT_UNUSED =
253253
return true;
254254
}
255255
}})
256-
.pattern({"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
257-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
258-
auto in = args[0].ITensorOrFreeze(ctx);
259-
auto ts = args[1].IValue()->toListRef();
260-
261-
std::vector<nvinfer1::ITensor*> tensors;
262-
for (auto t : ts) {
263-
if (t.isTensor()) {
264-
auto torch_tensor = t.toTensor();
265-
tensors.push_back(tensor_to_const(ctx, torch_tensor));
266-
} else {
267-
auto cont = t.toCustomClass<TensorContainer>();
268-
tensors.push_back(cont->tensor());
269-
}
270-
}
256+
.pattern(
257+
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
258+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
259+
auto in = args[0].ITensorOrFreeze(ctx);
260+
auto ts = args[1].IValue()->toListRef();
261+
262+
std::vector<nvinfer1::ITensor*> tensors;
263+
for (auto t : ts) {
264+
if (t.isTensor()) {
265+
auto torch_tensor = t.toTensor();
266+
tensors.push_back(tensor_to_const(ctx, torch_tensor));
267+
} else {
268+
auto cont = t.toCustomClass<TensorContainer>();
269+
tensors.push_back(cont->tensor());
270+
}
271+
}
271272

272-
TORCHTRT_CHECK(
273-
tensors.size() == 1,
274-
"This version of Torch-TensorRT only supports one index in aten::index.Tensor");
275-
auto indicesTensor = tensors[0];
276-
// Set datatype for indices tensor to INT32
277-
auto identity = ctx->net->addIdentity(*indicesTensor);
278-
identity->setOutputType(0, nvinfer1::DataType::kINT32);
279-
indicesTensor = identity->getOutput(0);
280-
281-
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
282-
// from
283-
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
284-
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
285-
auto gather_out = gather_layer->getOutput(0);
273+
// In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
274+
// indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
275+
// indexes the self tensor along dimension 0.
276+
TORCHTRT_CHECK(
277+
tensors.size() == 1,
278+
"In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
279+
auto indicesTensor = tensors[0];
280+
// Set datatype for indices tensor to INT32
281+
auto identity = ctx->net->addIdentity(*indicesTensor);
282+
identity->setOutputType(0, nvinfer1::DataType::kINT32);
283+
indicesTensor = identity->getOutput(0);
286284

287-
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
285+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
286+
// from
287+
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
288+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
289+
auto gather_out = gather_layer->getOutput(0);
288290

289-
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
290-
return true;
291-
}})
291+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
292+
293+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
294+
return true;
295+
}})
292296
.pattern(
293297
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
294298
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

0 commit comments

Comments
 (0)