Skip to content

Commit be69060

Browse files
authored
Merge pull request #116 from abhi-iyer/master
Add support for aten::select.int and aten::stack
2 parents 5127515 + f594e43 commit be69060

File tree

8 files changed

+299
-1
lines changed

8 files changed

+299
-1
lines changed

core/conversion/converters/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ cc_library(
2828
"impl/shuffle.cpp",
2929
"impl/softmax.cpp",
3030
"impl/unary.cpp",
31-
"impl/interpolate.cpp"
31+
"impl/interpolate.cpp",
32+
"impl/select.cpp",
33+
"impl/stack.cpp"
3234
],
3335
deps = [
3436
"@tensorrt//:nvinfer",
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#include "torch/torch.h"
2+
#include "core/util/prelude.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "NvInfer.h"
5+
6+
#include <ATen/ATen.h>
7+
#include <vector>
8+
9+
namespace trtorch {
10+
namespace core {
11+
namespace conversion {
12+
namespace converters {
13+
namespace impl {
14+
namespace {
15+
16+
auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
17+
.pattern({
18+
"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
19+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20+
auto in = args[0].ITensor();
21+
auto axis = args[1].unwrapToInt();
22+
auto ind = (int32_t) args[2].unwrapToInt();
23+
24+
// index to access needs to be an at::Tensor
25+
at::Tensor indices = torch::tensor({ind}).to(torch::kI32);
26+
auto weights = Weights(ctx, indices);
27+
28+
// IConstantLayer to convert indices from Weights to ITensor
29+
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
30+
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
31+
auto const_out = const_layer->getOutput(0);
32+
33+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
34+
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
35+
TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
36+
auto gather_out = gather_layer->getOutput(0);
37+
38+
// IShuffleLayer removes redundant dimensions
39+
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
40+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
41+
shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions()));
42+
shuffle_layer->setName(util::node_info(n).c_str());
43+
auto shuffle_out = shuffle_layer->getOutput(0);
44+
45+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out);
46+
47+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
48+
49+
return true;
50+
}
51+
});
52+
53+
} // namespace
54+
} // namespace impl
55+
} // namespace converters
56+
} // namespace conversion
57+
} // namespace core
58+
} // namespace trtorch
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "torch/torch.h"
2+
#include "core/util/prelude.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "core/conversion/tensorcontainer/TensorContainer.h"
5+
#include "NvInfer.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
auto stack_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18+
.pattern({
19+
"aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)",
20+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21+
auto in = args[0].IValue()->toListRef();
22+
auto dim = args[1].unwrapToInt();
23+
24+
std::vector<nvinfer1::ITensor*> tensors;
25+
26+
for (auto t : in) {
27+
nvinfer1::ITensor* itensor;
28+
29+
if (t.isTensor()) {
30+
auto weight = Weights(ctx, t.toTensor());
31+
32+
auto const_layer = ctx->net->addConstant(weight.shape, weight.data);
33+
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
34+
35+
itensor = const_layer->getOutput(0);
36+
} else {
37+
auto cont = t.toCustomClass<TensorContainer>();
38+
itensor = cont->tensor();
39+
}
40+
41+
auto shuffle_layer = ctx->net->addShuffle(*itensor);
42+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
43+
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim));
44+
45+
tensors.push_back(shuffle_layer->getOutput(0));
46+
}
47+
48+
auto concat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
49+
TRTORCH_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
50+
concat_layer->setAxis(static_cast<int>(dim));
51+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], concat_layer->getOutput(0));
52+
53+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
54+
55+
return true;
56+
}
57+
});
58+
59+
} // namespace
60+
} // namespace impl
61+
} // namespace converters
62+
} // namespace conversion
63+
} // namespace core
64+
} // namespace trtorch

core/util/trt_util.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,56 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
8282
return dims;
8383
}
8484

85+
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
  // Returns a copy of `d` with every *leading* dimension of size 1 removed.
  // Everything from the first non-1 dimension onward is kept, including any
  // later size-1 dimensions.
  nvinfer1::Dims dims;

  int out_idx = 0;
  bool keep_rest = false;

  for (int i = 0; i < d.nbDims; i++) {
    if (!keep_rest && d.d[i] == 1) {
      // Still inside the leading run of unit dimensions — drop it.
      continue;
    }
    // First non-1 dimension seen: keep this and all subsequent dims.
    keep_rest = true;
    dims.d[out_idx++] = d.d[i];
  }

  dims.nbDims = out_idx;

  return dims;
}
108+
109+
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) {
  // Returns a copy of `d` with an extra dimension of size 1 inserted at
  // index `pos`. Acceptable range for pos is [0, d.nbDims].
  TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds.");

  nvinfer1::Dims dims;

  int i = 0; // read cursor into d.d

  // The output has exactly d.nbDims + 1 entries, so iterate over the output
  // index j in [0, d.nbDims]. The previous `while (i <= d.nbDims)` form ran
  // one extra iteration after all source dims were consumed, reading
  // d.d[d.nbDims] and writing dims.d[d.nbDims + 1] — an out-of-bounds write
  // into Dims::d when d.nbDims == nvinfer1::Dims::MAX_DIMS - 1.
  for (int j = 0; j <= d.nbDims; j++) {
    if (j == pos) {
      // Add the new unit dimension at pos (read cursor does not advance).
      dims.d[j] = 1;
    } else {
      dims.d[j] = d.d[i++];
    }
  }

  dims.nbDims = d.nbDims + 1;

  return dims;
}
134+
85135
std::vector<int64_t> toVec(nvinfer1::Dims d) {
86136
std::vector<int64_t> dims;
87137
for (int i = 0; i < d.nbDims; i++) {

core/util/trt_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ int64_t volume(const nvinfer1::Dims& d);
7979

8080
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
8181
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
82+
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
83+
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos);
8284
nvinfer1::Dims toDims(c10::IntArrayRef l);
8385
nvinfer1::Dims toDims(c10::List<int64_t> l);
8486
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);

tests/core/converters/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ converter_test(
5959
name = "test_interpolate"
6060
)
6161

62+
converter_test(
63+
name = "test_select"
64+
)
65+
66+
converter_test(
67+
name = "test_stack"
68+
)
69+
6270
test_suite(
6371
name = "test_converters",
6472
tests = [
@@ -74,6 +82,8 @@ test_suite(
7482
":test_softmax",
7583
":test_unary",
7684
":test_interpolate",
85+
":test_select",
86+
":test_stack"
7787
]
7888
)
7989

tests/core/converters/test_select.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenSelectIntConvertsCorrectly) {
  // select(dim=0, index=0) on a 4x4x4 input; JIT and TensorRT must agree.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::select(%0, %2, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto input = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

  auto jit_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, jit_params, {at::clone(input)});

  auto trt_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, trt_params, {at::clone(input)});

  // Engine output may carry padded dims; compare after reshaping to the JIT shape.
  auto trt_out = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_out, 2e-6));
}
32+
33+
TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
  // Two chained selects: select(dim=0, index=0) then select(dim=0, index=3).
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=3]()
        %4 : Tensor = aten::select(%0, %2, %2)
        %5 : Tensor = aten::select(%4, %2, %3)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto input = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

  auto jit_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, jit_params, {at::clone(input)});

  auto trt_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, trt_params, {at::clone(input)});

  // Engine output may carry padded dims; compare after reshaping to the JIT shape.
  auto trt_out = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_out, 2e-6));
}

tests/core/converters/test_stack.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
  // Stack two runtime-input tensors along a new trailing axis (dim=3).
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=3]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
  auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

  auto jit_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, jit_params, {in1, in2});

  auto trt_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, trt_params, {in1, in2});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
30+
31+
TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
  // Stack a runtime input with a frozen parameter tensor (%1 becomes a
  // network constant via get_named_params) along dim=1.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(4, 4, 4)):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
  auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});

  auto jit_params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
  auto jit_results = trtorch::tests::util::RunGraph(g, jit_params, {in1});

  auto trt_params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, trt_params, {in1});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 commit comments

Comments
 (0)