Skip to content

Commit a00d1bf

Browse files
authored
Merge pull request #162 from abhi-iyer/LSTMCellConverter
LSTMCell converter
2 parents f1f4fce + 723ac1d commit a00d1bf

File tree

8 files changed

+395
-3
lines changed

8 files changed

+395
-3
lines changed

core/conversion/converters/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ cc_library(
4848
"impl/unary.cpp",
4949
"impl/interpolate.cpp",
5050
"impl/select.cpp",
51-
"impl/stack.cpp"
51+
"impl/stack.cpp",
52+
"impl/lstm_cell.cpp"
5253
],
5354
deps = [
5455
"@tensorrt//:nvinfer",
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include "torch/torch.h"
2+
#include "NvInfer.h"
3+
#include "core/util/prelude.h"
4+
#include "core/conversion/converters/converters.h"
5+
#include "core/conversion/tensorcontainer/TensorContainer.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
// Adds bias tensor `b` onto `a` with an elementwise SUM, inserting a shuffle
// (reshape) layer first when the two shapes differ so that `b` is padded to
// `a`'s rank. `b_name` is used only for log/error messages. Returns the
// output tensor of the SUM layer.
nvinfer1::ITensor* add_bias(nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::string b_name, ConversionCtx* ctx, const torch::jit::Node* n) {
  const auto a_dim = a->getDimensions();
  const auto b_dim = b->getDimensions();

  LOG_DEBUG(b_name << " tensor shape: " << b_dim);

  TRTORCH_CHECK(util::broadcastable(a_dim, b_dim, false), "bias " << b_name << " is not broadcastable - can't be added to previous matmul operation.");

  // If the shapes already match exactly, no reshape is needed; otherwise pad
  // the bias shape out to a's rank before the elementwise op.
  const bool needs_reshape = util::toVec(a_dim) != util::toVec(b_dim);
  if (needs_reshape) {
    LOG_DEBUG(b_name << "'s dimensions need to be reshaped");

    auto reshape_layer = ctx->net->addShuffle(*b);
    TRTORCH_CHECK(reshape_layer, "Unable to create shuffle layer from node: " << *n);
    reshape_layer->setReshapeDimensions(util::toDimsPad(util::toVec(b_dim), a_dim.nbDims));
    b = reshape_layer->getOutput(0);
  }

  LOG_DEBUG(b_name << "'s shape: " << b->getDimensions());

  auto sum_layer = ctx->net->addElementWise(*a, *b, nvinfer1::ElementWiseOperation::kSUM);
  TRTORCH_CHECK(sum_layer, "Unable to create ElementWise layer from node: " << *n);

  return sum_layer->getOutput(0);
}
42+
43+
// Converter for aten::lstm_cell. Builds one LSTM cell step as TensorRT layers:
//   gates = input @ w_ih^T (+ b_ih) + h_0 @ w_hh^T (+ b_hh)
//   i, f, g, o = chunk(gates, 4, dim=1)
//   cy = sigmoid(f) * c_0 + sigmoid(i) * tanh(g)
//   hy = sigmoid(o) * tanh(cy)
// Outputs (hy, cy) are associated with the node's two outputs.
// Assumes the fused gate tensor is 2D [batch, 4*hidden] — TODO confirm for
// all graphs this converter can see.
auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
  .pattern({
    "aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto input = args[0].ITensorOrFreeze(ctx);
      auto w_ih = args[2].ITensorOrFreeze(ctx);
      auto w_hh = args[3].ITensorOrFreeze(ctx);

      LOG_DEBUG("Input tensor shape: " << input->getDimensions());
      LOG_DEBUG("w_ih tensor shape: " << w_ih->getDimensions());
      LOG_DEBUG("w_hh tensor shape: " << w_hh->getDimensions());

      // hx is the state list [h_0, c_0]; each entry is either a frozen
      // at::Tensor or an already-converted ITensor held in a TensorContainer.
      std::vector<nvinfer1::ITensor*> state;
      auto hx = args[1].IValue()->toListRef();
      // state[0] and state[1] are indexed unconditionally below — fail loudly
      // instead of reading out of bounds on a malformed graph.
      TRTORCH_CHECK(hx.size() == 2, "Expected 2 state tensors (h_0, c_0) in hx for node: " << *n << ", found " << hx.size());
      for (unsigned int i = 0; i < hx.size(); i++) {
        auto t = hx[i];

        nvinfer1::ITensor* itensor;

        if (t.isTensor()) {
          itensor = tensor_to_const(ctx, t.toTensor());
        } else {
          auto cont = t.toCustomClass<TensorContainer>();
          itensor = cont->tensor();
        }

        LOG_DEBUG("State tensor " << i << " shape: " << itensor->getDimensions());
        state.push_back(itensor);
      }

      // calculate first half of gates: input @ w_ih^T (+ b_ih if present)
      auto mm1 = ctx->net->addMatrixMultiply(*input, nvinfer1::MatrixOperation::kNONE, *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE);
      TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n);
      auto mm1_out = mm1->getOutput(0);

      auto out1 = (args[4].isIValue() && args[4].IValue()->isNone()) ? mm1_out : add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n);

      // calculate second half of gates: h_0 @ w_hh^T (+ b_hh if present)
      auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE);
      TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n);
      auto mm2_out = mm2->getOutput(0);

      auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n);

      // get all 4 gates fused into one [batch, 4*hidden] tensor
      auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
      TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
      auto add_out = add->getOutput(0);

      // chunk Tensor into 4 parts and apply activation functions
      auto dims = util::toVec(add_out->getDimensions());
      auto batch = dims[0];
      auto hidden = dims[1]/4;

      std::vector<int64_t> size_vec = {batch, hidden};
      std::vector<int64_t> stride_vec = {1, 1};
      std::vector<int64_t> offset0 = {0, 0};
      std::vector<int64_t> offset1 = {0, hidden};
      std::vector<int64_t> offset2 = {0, 2*hidden};
      std::vector<int64_t> offset3 = {0, 3*hidden};

      auto size = util::toDims(size_vec);
      auto stride = util::toDims(stride_vec);

      // input gate: sigmoid(gates[:, 0:hidden])
      auto slice1 = ctx->net->addSlice(*add_out, util::toDims(offset0), size, stride);
      TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n);
      auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
      TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n);
      auto ingate = activ1->getOutput(0);

      // forget gate: sigmoid(gates[:, hidden:2*hidden])
      auto slice2 = ctx->net->addSlice(*add_out, util::toDims(offset1), size, stride);
      TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n);
      auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
      TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n);
      auto forgetgate = activ2->getOutput(0);

      // cell gate: tanh(gates[:, 2*hidden:3*hidden])
      auto slice3 = ctx->net->addSlice(*add_out, util::toDims(offset2), size, stride);
      TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n);
      auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH);
      TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n);
      auto cellgate = activ3->getOutput(0);

      // output gate: sigmoid(gates[:, 3*hidden:4*hidden])
      auto slice4 = ctx->net->addSlice(*add_out, util::toDims(offset3), size, stride);
      TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n);
      auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
      TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);
      auto outgate = activ4->getOutput(0);

      // compute cy = forgetgate * c_0 + ingate * cellgate
      auto forget_cx = ctx->net->addElementWise(*forgetgate, *state[1], nvinfer1::ElementWiseOperation::kPROD);
      TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
      auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
      TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
      auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
      TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
      auto cy_out = cy->getOutput(0);

      // compute hy = outgate * tanh(cy)
      auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
      TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
      auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
      TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
      auto hy_out = hy->getOutput(0);

      ctx->AssociateValueAndTensor(n->outputs()[0], hy_out);
      ctx->AssociateValueAndTensor(n->outputs()[1], cy_out);

      LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
      LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());

      return true;
    }
  });
156+
} // namespace
157+
} // namespace impl
158+
} // namespace converters
159+
} // namespace conversion
160+
} // namespace core
161+
} // namespace trtorch

core/lowering/lowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
3232
passes::RemoveDropout(g);
3333
passes::FuseFlattenLinear(g);
3434
passes::Conv2DToConvolution(g);
35+
passes::Conv3DToConvolution(g);
3536
passes::FuseAddMMBranches(g);
3637
torch::jit::EliminateCommonSubexpression(g);
37-
torch::jit::UnrollLoops(g);
38+
//torch::jit::UnrollLoops(g);
3839
torch::jit::EliminateCommonSubexpression(g);
3940
passes::UnpackAddMM(g);
4041
//passes::UnpackBatchNorm(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ cc_library(
1414
],
1515
srcs = [
1616
"conv2d_to_convolution.cpp",
17+
"conv3d_to_convolution.cpp",
1718
"exception_elimination.cpp",
1819
"fuse_addmm_branches.cpp",
1920
"fuse_flatten_linear.cpp",
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
// Lowering pass: rewrites aten::conv3d nodes into the generic
// aten::_convolution form (transposed=false, benchmark/deterministic/
// cudnn_enabled=false) that the conversion phase has a converter for.
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string conv3d_pattern = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
        %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
        return (%4))IR";
  // %1 = false (transposed & the trailing bool flags); %2 = output_padding,
  // which is only consulted for transposed convolutions so its length is
  // inert here — NOTE(review): confirm [0, 0] vs [0, 0, 0] against the
  // aten::_convolution schema if transposed conv3d support is ever added.
  std::string convolution_pattern = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
        %1 : bool = prim::Constant[value=0]()
        %2 : int[] = prim::Constant[value=[0, 0]]()
        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1)
        return (%4))IR";

  // replace conv3d with the equivalent _convolution node
  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
  map_conv3d_to_convolution.RegisterRewritePattern(
    conv3d_pattern, convolution_pattern);
  map_conv3d_to_convolution.runOnGraph(graph);
  LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
}
29+
30+
} // namespace passes
31+
} // namespace lowering
32+
} // namespace core
33+
} // namespace trtorch

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace lowering {
88
namespace passes {
99

1010
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
11+
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1112
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
1213
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
1314
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

tests/core/converters/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ converter_test(
6767
name = "test_stack"
6868
)
6969

70+
converter_test(
71+
name = "test_lstm_cell"
72+
)
73+
7074
test_suite(
7175
name = "test_converters",
7276
tests = [
@@ -83,6 +87,7 @@ test_suite(
8387
":test_unary",
8488
":test_interpolate",
8589
":test_select",
86-
":test_stack"
90+
":test_stack",
91+
":test_lstm_cell"
8792
]
8893
)

0 commit comments

Comments
 (0)