Skip to content

Commit dd88afc

Browse files
Avoid layer name conflicts in aten::index (#1377)
1 parent 9d89f6c commit dd88afc

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ auto select_registrations TORCHTRT_UNUSED =
362362
nvinfer1::ElementWiseOperation::kPROD,
363363
d0,
364364
dim_tensor,
365-
std::string("compute_dim0_") + std::to_string(i))
365+
util::node_info(n) + std::string("_compute_dim0_") + std::to_string(i))
366366
->getOutput(0);
367367
}
368368

@@ -378,7 +378,7 @@ auto select_registrations TORCHTRT_UNUSED =
378378
nvinfer1::ElementWiseOperation::kPROD,
379379
d1,
380380
dim_tensor,
381-
std::string("compute_dim1_") + std::to_string(i))
381+
util::node_info(n) + std::string("_compute_dim1_") + std::to_string(i))
382382
->getOutput(0);
383383
}
384384

@@ -398,26 +398,27 @@ auto select_registrations TORCHTRT_UNUSED =
398398
nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
399399
nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
400400
for (int i = adv_idx_count - 2; i >= 0; i--) {
401-
nvinfer1::ITensor* adv_index = add_elementwise(
402-
ctx,
403-
nvinfer1::ElementWiseOperation::kPROD,
404-
tensors[i],
405-
multiplier,
406-
std::string("adv_index_") + std::to_string(i))
407-
->getOutput(0);
401+
nvinfer1::ITensor* adv_index =
402+
add_elementwise(
403+
ctx,
404+
nvinfer1::ElementWiseOperation::kPROD,
405+
tensors[i],
406+
multiplier,
407+
util::node_info(n) + std::string("_adv_index_") + std::to_string(i))
408+
->getOutput(0);
408409
cum_adv_index = add_elementwise(
409410
ctx,
410411
nvinfer1::ElementWiseOperation::kSUM,
411412
cum_adv_index,
412413
adv_index,
413-
std::string("cum_adv_index_") + std::to_string(i))
414+
util::node_info(n) + std::string("_cum_adv_index_") + std::to_string(i))
414415
->getOutput(0);
415416
multiplier = add_elementwise(
416417
ctx,
417418
nvinfer1::ElementWiseOperation::kPROD,
418419
multiplier,
419420
dim_tensor_list[adv_idx_indices[i]],
420-
std::string("multiplier_") + std::to_string(i))
421+
util::node_info(n) + std::string("_multiplier_") + std::to_string(i))
421422
->getOutput(0);
422423
}
423424

tests/core/conversion/converters/test_select.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,38 @@ TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) {
833833
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
834834
}
835835

836+
TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) {
837+
const auto graph = R"IR(
838+
graph(%x.1 : Tensor,
839+
%index0 : Tensor,
840+
%index1 : Tensor,
841+
%index2 : Tensor):
842+
%18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2)
843+
%19 : Tensor = aten::index(%x.1, %18)
844+
%20 : Tensor = aten::index(%x.1, %18)
845+
return (%19, %20))IR";
846+
847+
auto g = std::make_shared<torch::jit::Graph>();
848+
torch::jit::parseIR(graph, g.get());
849+
850+
auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
851+
auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
852+
auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
853+
auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
854+
auto index0_trt = index0.to(torch::kInt32);
855+
auto index1_trt = index1.to(torch::kInt32);
856+
auto index2_trt = index2.to(torch::kInt32);
857+
858+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
859+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});
860+
861+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
862+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});
863+
864+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
865+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6));
866+
}
867+
836868
TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) {
837869
const auto graph = R"IR(
838870
graph(%x.1 : Tensor,

0 commit comments

Comments
 (0)