Skip to content

Commit 207b1b1

Browse files
authored
Merge pull request #921 from ruoqianguo/aten_index_tensor
support aten::index.Tensor
2 parents d93659d + 6701759 commit 207b1b1

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,46 @@ auto select_registrations TORCHTRT_UNUSED =
253253
return true;
254254
}
255255
}})
256+
.pattern(
257+
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
258+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
259+
auto in = args[0].ITensorOrFreeze(ctx);
260+
auto ts = args[1].IValue()->toListRef();
261+
262+
std::vector<nvinfer1::ITensor*> tensors;
263+
for (auto t : ts) {
264+
if (t.isTensor()) {
265+
auto torch_tensor = t.toTensor();
266+
tensors.push_back(tensor_to_const(ctx, torch_tensor));
267+
} else {
268+
auto cont = t.toCustomClass<TensorContainer>();
269+
tensors.push_back(cont->tensor());
270+
}
271+
}
272+
273+
// In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
274+
// indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
275+
// indexes the self tensor along dimension 0.
276+
TORCHTRT_CHECK(
277+
tensors.size() == 1,
278+
"In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
279+
auto indicesTensor = tensors[0];
280+
// Set datatype for indices tensor to INT32
281+
auto identity = ctx->net->addIdentity(*indicesTensor);
282+
identity->setOutputType(0, nvinfer1::DataType::kINT32);
283+
indicesTensor = identity->getOutput(0);
284+
285+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
286+
// from
287+
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
288+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
289+
auto gather_out = gather_layer->getOutput(0);
290+
291+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
292+
293+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
294+
return true;
295+
}})
256296
.pattern(
257297
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
258298
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

tests/core/conversion/converters/test_select.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,29 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
541541
ASSERT_TRUE(
542542
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
543543
}
544+
545+
TEST(Converters, ATenIndexTensorConvertsCorrectly) {
  // Graph under test: gather rows of %x.1 along dim 0 using %index
  // via aten::index with a single-element Tensor?[] list.
  const auto graph = R"IR(
    graph(%x.1 : Tensor,
          %index : Tensor):
      %18 : Tensor?[] = prim::ListConstruct(%index)
      %19 : Tensor = aten::index(%x.1, %18)
      return (%19))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // Self tensor and an index tensor selecting row 4 twice.
  auto self = at::randint(1, 10, {5, 10}, {at::kCUDA});
  auto jit_index = at::full({2}, 4, {at::kCUDA});
  auto fp32_opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  // TRT-side index is float32; the converter casts indices to INT32 internally.
  auto trt_index = at::full({2}, 4, {fp32_opts});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, static_params, {self, jit_index});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results =
      torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, static_params, {self, trt_index});

  // Compare JIT and TensorRT outputs elementwise within tolerance.
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 commit comments

Comments
 (0)