Skip to content

Commit ede9d13

Browse files
committed
support aten::index.Tensor
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent d7c2794 commit ede9d13

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,42 @@ auto select_registrations TORCHTRT_UNUSED =
253253
return true;
254254
}
255255
}})
256+
.pattern({"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
257+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
258+
auto in = args[0].ITensorOrFreeze(ctx);
259+
auto ts = args[1].IValue()->toListRef();
260+
261+
std::vector<nvinfer1::ITensor*> tensors;
262+
for (auto t : ts) {
263+
if (t.isTensor()) {
264+
auto torch_tensor = t.toTensor();
265+
tensors.push_back(tensor_to_const(ctx, torch_tensor));
266+
} else {
267+
auto cont = t.toCustomClass<TensorContainer>();
268+
tensors.push_back(cont->tensor());
269+
}
270+
}
271+
272+
TORCHTRT_CHECK(
273+
tensors.size() == 1,
274+
"This version of Torch-TensorRT only supports one index in aten::index.Tensor");
275+
auto indicesTensor = tensors[0];
276+
// Set datatype for indices tensor to INT32
277+
auto identity = ctx->net->addIdentity(*indicesTensor);
278+
identity->setOutputType(0, nvinfer1::DataType::kINT32);
279+
indicesTensor = identity->getOutput(0);
280+
281+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
282+
// from
283+
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
284+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
285+
auto gather_out = gather_layer->getOutput(0);
286+
287+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
288+
289+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
290+
return true;
291+
}})
256292
.pattern(
257293
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
258294
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

tests/core/conversion/converters/test_select.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,29 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
541541
ASSERT_TRUE(
542542
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
543543
}
544+
545+
TEST(Converters, ATenIndexTensorConvertsCorrectly) {
546+
const auto graph = R"IR(
547+
graph(%x.1 : Tensor,
548+
%index : Tensor):
549+
%18 : Tensor?[] = prim::ListConstruct(%index)
550+
%19 : Tensor = aten::index(%x.1, %18)
551+
return (%19))IR";
552+
553+
auto g = std::make_shared<torch::jit::Graph>();
554+
torch::jit::parseIR(graph, g.get());
555+
556+
auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA});
557+
auto in2 = at::full({2}, 4, {at::kCUDA});
558+
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
559+
auto in2_trt = at::full({2}, 4, {options});
560+
561+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
562+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
563+
564+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
565+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt});
566+
567+
ASSERT_TRUE(
568+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
569+
}

0 commit comments

Comments
 (0)