Skip to content

Commit 3decf45

Browse files
authored
Merge pull request #1491 from mfeliz-cruise/michael.feliz/1d_topk
[feat] Support 1D topk
2 parents e776efb + dacf483 commit 3decf45

File tree

4 files changed

+60
-8
lines changed

4 files changed

+60
-8
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ nvinfer1::ITensor* addPadding(
1313
nvinfer1::ITensor* tensor,
1414
int nDim,
1515
bool trailing,
16-
bool use_zeros) {
16+
bool use_zeros,
17+
const std::string& name) {
1718
const auto dims = tensor->getDimensions();
1819

1920
if (dims.nbDims < nDim) {
@@ -27,7 +28,11 @@ nvinfer1::ITensor* addPadding(
2728
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer");
2829
shuffle_layer->setReshapeDimensions(newDims);
2930
shuffle_layer->setZeroIsPlaceholder(use_zeros);
30-
shuffle_layer->setName((util::node_info(n) + " [Reshape to " + util::toStr(newDims) + ']').c_str());
31+
if (name.size()) {
32+
shuffle_layer->setName(name.c_str());
33+
} else {
34+
shuffle_layer->setName((util::node_info(n) + " [Reshape to " + util::toStr(newDims) + ']').c_str());
35+
}
3136
return shuffle_layer->getOutput(0);
3237
} else {
3338
return tensor;
@@ -40,7 +45,8 @@ nvinfer1::ITensor* addUnpadding(
4045
nvinfer1::ITensor* tensor,
4146
int nDim,
4247
bool trailing,
43-
bool use_zeros) {
48+
bool use_zeros,
49+
const std::string& name) {
4450
const auto dims = tensor->getDimensions();
4551
if (dims.nbDims > nDim) {
4652
auto newDims = dims;
@@ -52,7 +58,11 @@ nvinfer1::ITensor* addUnpadding(
5258
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer");
5359
shuffle_layer->setReshapeDimensions(newDims);
5460
shuffle_layer->setZeroIsPlaceholder(use_zeros);
55-
shuffle_layer->setName((util::node_info(n) + " [Reshape to " + util::toStr(newDims) + "]").c_str());
61+
if (name.size()) {
62+
shuffle_layer->setName(name.c_str());
63+
} else {
64+
shuffle_layer->setName((util::node_info(n) + " [Reshape to " + util::toStr(newDims) + ']').c_str());
65+
}
5666
return shuffle_layer->getOutput(0);
5767
} else {
5868
return tensor;

core/conversion/converters/converter_util.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ nvinfer1::ITensor* addPadding(
2222
nvinfer1::ITensor* tensor,
2323
int nDim,
2424
bool trailing = true,
25-
bool use_zeros = true);
25+
bool use_zeros = true,
26+
const std::string& name = "");
2627

2728
// If nDim < tensor size, adds shuffle layer to un-pad tensor (at the end if trailing) and returns (nDim-dimensional)
2829
// shuffle layer's output. Otherwise, does nothing and passes tensor through. use_zeros controls whether we should be
@@ -33,7 +34,8 @@ nvinfer1::ITensor* addUnpadding(
3334
nvinfer1::ITensor* tensor,
3435
int nDim,
3536
bool trailing = true,
36-
bool use_zeros = true);
37+
bool use_zeros = true,
38+
const std::string& name = "");
3739

3840
// TODO: Change add_elementwise schema to output nvinfer1::ITensor* instead of nvinfer1::ILayer*,
3941
// for consistency with other utils. Need to change schema and usage in all calling contexts

core/conversion/converters/impl/topk.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,29 @@ auto topk_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patte
3939

4040
LOG_DEBUG("Output topk reduce dim: " << dim);
4141

42+
// The topk layer requires at least 2 input dimensions
43+
auto nbDims = self->getDimensions().nbDims;
44+
if (nbDims == 1) {
45+
self = addPadding(ctx, n, self, 2, true, true);
46+
}
47+
4248
auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);
4349

4450
auto new_layer = ctx->net->addTopK(*self, TopKOperation, k, shiftDim);
4551

4652
TORCHTRT_CHECK(new_layer, "Unable to create topk layer from node: " << *n);
4753

48-
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
49-
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
54+
auto values = new_layer->getOutput(0);
55+
auto indices = new_layer->getOutput(1);
56+
57+
// If we expanded the input, squeeze the outputs
58+
if (nbDims == 1) {
59+
values = addUnpadding(ctx, n, values, 1, true, true, util::node_info(n) + "_squeeze_values");
60+
indices = addUnpadding(ctx, n, indices, 1, true, true, util::node_info(n) + "_squeeze_indices");
61+
}
5062

63+
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], values);
64+
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], indices);
5165
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
5266
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
5367

tests/core/conversion/converters/test_topk.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,29 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
3030
ASSERT_TRUE(
3131
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 8e-5));
3232
}
33+
34+
TEST(Converters, ATen1DTopKConvertsCorrectly) {
35+
const auto graph = R"IR(
36+
graph(%0 : Tensor):
37+
%1 : int = prim::Constant[value=20]()
38+
%2 : int = prim::Constant[value=-1]()
39+
%3 : bool = prim::Constant[value=1]()
40+
%4 : bool = prim::Constant[value=1]()
41+
%5 : Tensor, %6 : Tensor = aten::topk(%0, %1, %2, %3, %4)
42+
return (%5, %6))IR";
43+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
44+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
45+
auto g = std::make_shared<torch::jit::Graph>();
46+
torch::jit::parseIR(graph, g.get());
47+
48+
auto in = at::rand({100}, {at::kCUDA});
49+
50+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
51+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
52+
53+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
54+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
55+
56+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 8e-5));
57+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 8e-5));
58+
}

0 commit comments

Comments
 (0)