
Commit 9db2852

[feat] Add support for argmax and argmin (#1312)
* [feat] Add support for argmax and argmin

  Adds converters for aten::argmax and aten::argmin (a non-breaking new feature). Also extends the existing aten::max.dim converter to cover aten::min.dim.

* Move the max.cpp tests out of test_topk.cpp into a new test_max.cpp; no functional change.

* Fix file permissions on max.cpp.
1 parent 2af5942 commit 9db2852
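
For context, the ATen semantics the new converters target. A minimal libtorch sketch (illustrative only, not part of the commit):

#include <torch/torch.h>

#include <iostream>
#include <tuple>

int main() {
  auto t = torch::rand({2, 3, 5, 5});

  // aten::max.dim returns (values, indices) along `dim`.
  at::Tensor values, max_idx;
  std::tie(values, max_idx) = torch::max(t, /*dim=*/0, /*keepdim=*/false);

  // aten::argmax / aten::argmin return only the index tensor.
  auto amax = torch::argmax(t, /*dim=*/0, /*keepdim=*/false);
  auto amin = torch::argmin(t, /*dim=*/0, /*keepdim=*/false);

  // With keepdim=false the reduced dim is squeezed away: [2,3,5,5] -> [3,5,5].
  std::cout << values.sizes() << " " << amax.sizes() << " " << amin.sizes() << "\n";

  // argmax matches max.dim's index output (up to tie-breaking, which is
  // vanishingly rare for random floats).
  std::cout << std::boolalpha << torch::equal(amax, max_idx) << "\n";
}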

File tree

4 files changed: +241 −66 lines


core/conversion/converters/impl/max.cpp

Lines changed: 89 additions & 41 deletions
@@ -13,47 +13,95 @@ namespace conversion {
 namespace converters {
 namespace impl {
 namespace {
-auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto self = args[0].ITensorOrFreeze(ctx);
-       auto dim = args[1].unwrapToInt();
-       auto keep_dims = args[2].unwrapToBool();
-       auto selfDim = util::toVec(self->getDimensions());
-       if (dim < 0) {
-         dim = selfDim.size() + dim;
-       }
-       uint32_t shiftDim = 1 << dim;
-       auto TopKOperation = nvinfer1::TopKOperation::kMAX;
-       auto topk_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
-       TORCHTRT_CHECK(topk_layer, "Unable to create max layer from node: " << *n);
-       auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());
-
-       nvinfer1::ITensor* out0 = nullptr;
-       nvinfer1::ITensor* out1 = nullptr;
-       if (!keep_dims) {
-         if (topk_dims[dim] == 1) {
-           auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0));
-           squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim));
-           TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n);
-           out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0));
-
-           auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
-           squeeze_layer_indices->setReshapeDimensions(
-               util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
-           TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
-           out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0));
-         }
-       } else {
-         out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
-         out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
-       }
-
-       LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
-       LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
-
-       return true;
-     }});
+
+bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvinfer1::TopKOperation topKOperation) {
+  auto self = args[0].ITensorOrFreeze(ctx);
+  auto dim = args[1].unwrapToInt();
+  auto keep_dims = args[2].unwrapToBool();
+  auto selfDim = util::toVec(self->getDimensions());
+  if (dim < 0) {
+    dim = selfDim.size() + dim;
+  }
+  uint32_t reduce_axes_mask = 1 << dim;
+  auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
+  TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
+  auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());
+
+  nvinfer1::ITensor* out0 = nullptr;
+  nvinfer1::ITensor* out1 = nullptr;
+  if (!keep_dims) {
+    TORCHTRT_CHECK(topk_dims[dim] == 1, "Unexpected size in squeeze dimension. Expected: 1 Actual: " << topk_dims[dim]);
+    auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0));
+    squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim));
+    TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n);
+    out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0));
+
+    auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
+    squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
+    TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
+    out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0));
+  } else {
+    out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
+    out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
+  }
+
+  LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
+  LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
+
+  return true;
+}
+
+bool arg_min_max(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvinfer1::TopKOperation topKOperation) {
+  auto self = args[0].ITensorOrFreeze(ctx);
+  auto dim = args[1].unwrapToInt();
+  auto keep_dims = args[2].unwrapToBool();
+  auto selfDim = util::toVec(self->getDimensions());
+  if (dim < 0) {
+    dim = selfDim.size() + dim;
+  }
+  uint32_t reduce_axes_mask = 1 << dim;
+  auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
+  TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
+  auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());
+
+  nvinfer1::ITensor* out = nullptr;
+  if (!keep_dims) {
+    TORCHTRT_CHECK(topk_dims[dim] == 1, "Unexpected size in squeeze dimension. Expected: 1 Actual: " << topk_dims[dim]);
+    auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
+    squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
+    TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
+    out = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer_indices->getOutput(0));
+  } else {
+    out = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(1));
+  }
+
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+
+  return true;
+}
+
+auto max_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               return min_max_dim(ctx, n, args, nvinfer1::TopKOperation::kMAX);
+             }})
+        .pattern(
+            {"aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               return min_max_dim(ctx, n, args, nvinfer1::TopKOperation::kMIN);
+             }})
+        .pattern(
+            {"aten::argmax(Tensor self, int dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               return arg_min_max(ctx, n, args, nvinfer1::TopKOperation::kMAX);
+             }})
+        .pattern(
+            {"aten::argmin(Tensor self, int dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               return arg_min_max(ctx, n, args, nvinfer1::TopKOperation::kMIN);
+             }});
 } // namespace
 } // namespace impl
 } // namespace converters
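
Both helpers select along exactly one axis by passing TensorRT's addTopK a bit mask of reduction axes (reduce_axes_mask) rather than an axis index, after wrapping negative dims. A standalone sketch of just that index math, separate from the converter (illustrative only):

#include <cstdint>
#include <iostream>

// Mirrors the converter's dim handling: wrap negative dims against the
// tensor rank, then build the one-hot axes mask that ITopKLayer expects
// (bit i set => operate over axis i).
uint32_t topk_axes_mask(int64_t dim, int64_t rank) {
  if (dim < 0) {
    dim = rank + dim;  // e.g. dim=-1 on a rank-4 tensor becomes dim=3
  }
  return 1u << dim;
}

int main() {
  std::cout << topk_axes_mask(0, 4) << "\n";   // 1  (0b0001)
  std::cout << topk_axes_mask(2, 4) << "\n";   // 4  (0b0100)
  std::cout << topk_axes_mask(-1, 4) << "\n";  // 8  (0b1000)
}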

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,10 @@ converter_test(
     name = "test_matrix_multiply",
 )
 
+converter_test(
+    name = "test_max",
+)
+
 converter_test(
     name = "test_normalize",
 )
@@ -156,6 +160,7 @@ test_suite(
         ":test_linear",
         ":test_lstm_cell",
         ":test_matrix_multiply",
+        ":test_max",
         ":test_normalize",
         ":test_pooling",
         ":test_reduce",
tests/core/conversion/converters/test_max.cpp

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Converters, ATenMaxDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=0]()
+      %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
+      return (%4, %5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
+}
+
+TEST(Converters, ATenMinDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=0]()
+      %4 : Tensor, %5 : Tensor = aten::min(%x.1, %2, %3)
+      return (%4, %5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
+}
+
+TEST(Converters, ATenArgMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=0]()
+      %4 : Tensor = aten::argmax(%x.1, %2, %3)
+      return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenArgMaxKeepdimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=1]()
+      %3 : bool = prim::Constant[value=1]()
+      %4 : Tensor = aten::argmax(%x.1, %2, %3)
+      return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenArgMinConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=0]()
+      %4 : Tensor = aten::argmin(%x.1, %2, %3)
+      return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenArgMinKeepdimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=1]()
+      %3 : bool = prim::Constant[value=1]()
+      %4 : Tensor = aten::argmin(%x.1, %2, %3)
+      return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
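
All six tests use a non-negative dim, so the converter's negative-dim wrap-around goes unexercised here. A hypothetical extra case in the same style (not part of this commit) might look like:

TEST(Converters, ATenArgMaxNegativeDimConvertsCorrectly) {
  // Hypothetical: dim=-1 exercises the `if (dim < 0)` branch in arg_min_max.
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=-1]()
      %3 : bool = prim::Constant[value=0]()
      %4 : Tensor = aten::argmax(%x.1, %2, %3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}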

tests/core/conversion/converters/test_topk.cpp

Lines changed: 0 additions & 25 deletions
@@ -30,28 +30,3 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
 }
-
-TEST(Converters, ATenMaxDimConvertsCorrectly) {
-  const auto graph = R"IR(
-    graph(%x.1 : Tensor):
-      %2 : int = prim::Constant[value=0]()
-      %3 : bool = prim::Constant[value=0]()
-      %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
-      return (%4, %5))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
-}
