#include "torch/torch.h"
#include "NvInfer.h"
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"

#include <ATen/ATen.h>
#include <vector>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

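// Broadcasts bias b to the rank of the matmul output a (padding leading
// dimensions with 1s via an IShuffleLayer when the shapes differ), then
// returns the elementwise sum a + b.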
nvinfer1::ITensor* add_bias(nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::string b_name, ConversionCtx* ctx, const torch::jit::Node* n) {
  auto a_dim = a->getDimensions();
  auto b_dim = b->getDimensions();

  LOG_DEBUG(b_name << " tensor shape: " << b_dim);

  TRTORCH_CHECK(util::broadcastable(a_dim, b_dim, false), "bias " << b_name << " is not broadcastable - can't be added to previous matmul operation.");

  if (util::toVec(a_dim) != util::toVec(b_dim)) {
    LOG_DEBUG(b_name << "'s dimensions need to be reshaped");

    auto shuffle = ctx->net->addShuffle(*b);
    TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
    shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_dim), a_dim.nbDims));

    b = shuffle->getOutput(0);
  }

  LOG_DEBUG(b_name << "'s shape: " << b->getDimensions());

  auto add = ctx->net->addElementWise(*a, *b, nvinfer1::ElementWiseOperation::kSUM);
  TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);

  return add->getOutput(0);
}

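// Converter for aten::lstm_cell. The standard LSTM cell update, with the
// four gates packed along dim 1 in PyTorch's [i | f | g | o] order, is:
//   gates = input @ w_ih^T + b_ih + hx @ w_hh^T + b_hh
//   i = sigmoid(...), f = sigmoid(...), g = tanh(...), o = sigmoid(...)
//   cy = f * cx + i * g
//   hy = o * tanh(cy)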
auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
  .pattern({
    "aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto input = args[0].ITensorOrFreeze(ctx);
      auto w_ih = args[2].ITensorOrFreeze(ctx);
      auto w_hh = args[3].ITensorOrFreeze(ctx);

      LOG_DEBUG("Input tensor shape: " << input->getDimensions());
      LOG_DEBUG("w_ih tensor shape: " << w_ih->getDimensions());
      LOG_DEBUG("w_hh tensor shape: " << w_hh->getDimensions());

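      // hx is a list holding the two state tensors (hidden state hx, cell
      // state cx). Each element may arrive as a constant at::Tensor or as an
      // ITensor wrapped in a TensorContainer, so unwrap it accordingly.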
      std::vector<nvinfer1::ITensor*> state;
      auto hx = args[1].IValue()->toListRef();
      for (unsigned int i = 0; i < hx.size(); i++) {
        auto t = hx[i];

        nvinfer1::ITensor* itensor;

        if (t.isTensor()) {
          itensor = tensor_to_const(ctx, t.toTensor());
        } else {
          auto cont = t.toCustomClass<TensorContainer>();
          itensor = cont->tensor();
        }

        LOG_DEBUG("State tensor " << i << " shape: " << itensor->getDimensions());
        state.push_back(itensor);
      }

      // calculate first half of gates
      auto mm1 = ctx->net->addMatrixMultiply(*input, nvinfer1::MatrixOperation::kNONE, *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE);
      TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n);
      auto mm1_out = mm1->getOutput(0);

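      // b_ih is optional (Tensor? b_ih=None in the schema), so skip the bias
      // add when it is None; the same applies to b_hh below.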
      auto out1 = (args[4].isIValue() && args[4].IValue()->isNone()) ? mm1_out : add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n);

      // calculate second half of gates
      auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE);
      TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n);
      auto mm2_out = mm2->getOutput(0);

      auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n);

      // get all 4 gates
      auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
      TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
      auto add_out = add->getOutput(0);

      // chunk Tensor into 4 parts and apply activation functions
      auto dims = util::toVec(add_out->getDimensions());
      auto batch = dims[0];
      auto hidden = dims[1] / 4;

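      // add_out has shape [batch, 4 * hidden]: each gate occupies a
      // contiguous [batch, hidden] block, so extract it with an ISliceLayer
      // at stride 1 using an offset of 0, hidden, 2*hidden, or 3*hidden.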
      std::vector<int64_t> size_vec = {batch, hidden};
      std::vector<int64_t> stride_vec = {1, 1};
      std::vector<int64_t> offset0 = {0, 0};
      std::vector<int64_t> offset1 = {0, hidden};
      std::vector<int64_t> offset2 = {0, 2 * hidden};
      std::vector<int64_t> offset3 = {0, 3 * hidden};

      auto size = util::toDims(size_vec);
      auto stride = util::toDims(stride_vec);

      auto slice1 = ctx->net->addSlice(*add_out, util::toDims(offset0), size, stride);
      TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n);
      auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
      TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n);
      auto ingate = activ1->getOutput(0);

      auto slice2 = ctx->net->addSlice(*add_out, util::toDims(offset1), size, stride);
      TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n);
      auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
      TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n);
      auto forgetgate = activ2->getOutput(0);

      auto slice3 = ctx->net->addSlice(*add_out, util::toDims(offset2), size, stride);
      TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n);
      auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH);
      TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n);
      auto cellgate = activ3->getOutput(0);

      auto slice4 = ctx->net->addSlice(*add_out, util::toDims(offset3), size, stride);
      TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n);
      auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
      TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);
      auto outgate = activ4->getOutput(0);

      // compute cy
      auto forget_cx = ctx->net->addElementWise(*forgetgate, *state[1], nvinfer1::ElementWiseOperation::kPROD);
      TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
      auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
      TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
      auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
      TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
      auto cy_out = cy->getOutput(0);

      // compute hy
      auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
      TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
      auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
      TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
      auto hy_out = hy->getOutput(0);

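      // bind the TRT tensors to the node's two outputs, matching the
      // (Tensor, Tensor) return of the schema: hy first, then cy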
      ctx->AssociateValueAndTensor(n->outputs()[0], hy_out);
      ctx->AssociateValueAndTensor(n->outputs()[1], cy_out);

      LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
      LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());

      return true;
    }
  });
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch