@@ -24,7 +24,8 @@ nvinfer1::ILayer* add_elementwise(
  auto selfDim = util::toVec(self->getDimensions());
  auto otherDim = util::toVec(other->getDimensions());
  if (selfDim.size() != otherDim.size()) {
-    // other is with dynamic shape, need to expand its dimension now and get its shape at runtime
+    // other is with dynamic shape, need to expand its dimension now and get its
+    // shape at runtime
    if (otherDim.end() != std::find(otherDim.begin(), otherDim.end(), -1)) {
      auto thOtherStaticShapeMask = torch::ones(selfDim.size(), torch::kInt32);
      auto thOtherDynamicShapeMask = torch::zeros(selfDim.size(), torch::kInt32);
@@ -39,7 +40,8 @@ nvinfer1::ILayer* add_elementwise(
      auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
      auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
      auto selfShape = ctx->net->addShape(*self)->getOutput(0);
-      // size of dynamic dimension of other need to the same as that of corresponding dimension of self
+      // size of dynamic dimension of other need to the same as that of
+      // corresponding dimension of self
      auto otherDynamicShape =
          ctx->net->addElementWise(*selfShape, *otherDynamicShapeMask, nvinfer1::ElementWiseOperation::kPROD)
              ->getOutput(0);
@@ -52,7 +54,8 @@ nvinfer1::ILayer* add_elementwise(
      otherShuffle->setInput(1, *targetOtherShape);
      other = otherShuffle->getOutput(0);
    } else {
-      // other is with static shape, expand dimension to make tow tensor have the same number of dimension
+      // other is with static shape, expand dimension to make tow tensor have
+      // the same number of dimension
      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setReshapeDimensions(util::toDimsPad(otherDim, selfDim.size()));
      other = otherShuffle->getOutput(0);
@@ -68,28 +71,26 @@ nvinfer1::ILayer* add_elementwise(
  return ele;
}

-
auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
    R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
                     float eps, bool cudnn_enabled) -> (Tensor))SIG",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-
      auto input = args[0].ITensor(); // assumes non-static input Tensor
      auto orig_shape = input->getDimensions();
      auto shape = util::toVec(orig_shape);

      /* Layer_Norm normalizes over last N dimensions.
         normalizaed_shape could be (C,H,W), (H,W), or (W). */
-
+
      auto normalized_shape = args[1].unwrapToIntList();
      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));

-
      torch::Tensor gamma, beta;
      gamma = args[2].unwrapToTensor();
      beta = args[3].unwrapToTensor();

-      // Remove batch dimension from input shape for expand_size, which will be used to create weights for addScaleNd later.
+      // Remove batch dimension from input shape for expand_size, which will
+      // be used to create weights for addScaleNd later.
      auto expand_size = shape;
      expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
      auto gamma_expand = gamma.expand(expand_size);
@@ -119,15 +120,15 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
        repeats_expected.push_back(repeat);
      }

-      int repeats_expected_rank = repeats_expected.size(); // 4
-      auto mean_layer_expected_out_dims = mean_layer_expected_out->getDimensions(); // 1
-      auto num_expand_dims_expected = repeats_expected_rank - mean_layer_expected_out_dims.nbDims; // 3
-
+      int repeats_expected_rank = repeats_expected.size();
+      auto mean_layer_expected_out_dims = mean_layer_expected_out->getDimensions();
+      auto num_expand_dims_expected = repeats_expected_rank - mean_layer_expected_out_dims.nbDims;
+
      if (num_expand_dims_expected > 0) {
        nvinfer1::Dims reshape_expected_dims;
        reshape_expected_dims.nbDims = repeats_expected.size();
        for (int i = 0; i < num_expand_dims_expected; i++) {
-          reshape_expected_dims.d[repeats_expected.size() - 1 - i ] = 1;
+          reshape_expected_dims.d[repeats_expected.size() - 1 - i] = 1;
        }
        for (int i = 0; i < mean_layer_expected_out_dims.nbDims; i++) {
          reshape_expected_dims.d[i] = mean_layer_expected_out_dims.d[i];
@@ -147,44 +148,49 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
        concat_layer->setAxis(i);
        mean_layer_expected_out = concat_layer->getOutput(0);
      }
-

      // X-E[x]
-      auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_layer_expected_out, (util::node_info(n) + "_sub").c_str());
+      auto sub = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kSUB,
+          input,
+          mean_layer_expected_out,
+          (util::node_info(n) + "_sub").c_str());
      TRTORCH_CHECK(sub, "Unable to create Add layer from node: " << *n);
      sub->setName((util::node_info(n) + "_sub").c_str());
      auto xsubmean = sub->getOutput(0);

      // Variance
      float pow_scalar = 2;
      auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
-      auto pow = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean, exponent, (util::node_info(n) +"_pow").c_str());
+      auto pow = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean, exponent, (util::node_info(n) + "_pow").c_str());
      TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
-      pow->setName((util::node_info(n) +"_pow").c_str());
+      pow->setName((util::node_info(n) + "_pow").c_str());
      auto pow_out = pow->getOutput(0);

      auto mean_layer_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
      TRTORCH_CHECK(mean_layer_var, "Unable to create mean_layer_var from node: " << *n);
      mean_layer_var->setName((util::node_info(n) + "_mean_var").c_str());
      auto mean_layer_var_out = mean_layer_var->getOutput(0);
-
-      // Expand output of mean_layer_var to the same shape as original input.
+
+      // Expand output of mean_layer_var to the same shape as original
+      // input.
      c10::List<int64_t> repeats_var;
      for (size_t i = 0; i < shape.size(); i++) {
-        auto repeat = i > (shape.size()- normalized_shape_vec.size()- 1) ? shape[i] : 1;
+        auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
        repeats_var.push_back(repeat);
      }

      int repeats_var_rank = repeats_var.size();
-      auto mean_layer_var_out_dims = mean_layer_var_out->getDimensions(); 
+      auto mean_layer_var_out_dims = mean_layer_var_out->getDimensions();
      auto num_expand_dims_var = repeats_var_rank - mean_layer_var_out_dims.nbDims;
-

      if (num_expand_dims_var > 0) {
        nvinfer1::Dims reshape_dims_var;
        reshape_dims_var.nbDims = repeats_var.size();
        for (int i = 0; i < num_expand_dims_var; i++) {
-          reshape_dims_var.d[repeats_var.size() - 1 - i ] = 1;
+          reshape_dims_var.d[repeats_var.size() - 1 - i] = 1;
        }
        for (int i = 0; i < mean_layer_var_out_dims.nbDims; i++) {
          reshape_dims_var.d[i] = mean_layer_var_out_dims.d[i];
@@ -195,7 +201,6 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
        reshape_layer_var->setReshapeDimensions(reshape_dims_var);
        mean_layer_var_out = reshape_layer_var->getOutput(0);
      }
-

      for (int i = repeats_var.size() - 1; i >= 0; --i) {
        std::vector<nvinfer1::ITensor*> tensors_vec;
@@ -207,34 +212,41 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
        mean_layer_var_out = concat_layer->getOutput(0);
      }

-      // add eps 
+      // add eps
      auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
-      auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, mean_layer_var_out, eps_tensor, (util::node_info(n)+"_add").c_str());
+      auto add = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kSUM,
+          mean_layer_var_out,
+          eps_tensor,
+          (util::node_info(n) + "_add").c_str());
      TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n);
-      add->setName((util::node_info(n)+ "_add").c_str());
+      add->setName((util::node_info(n) + "_add").c_str());
      auto add_out = add->getOutput(0);

      // add Unary layer for sqrt((var + eps))
      auto unary = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
-      TRTORCH_CHECK(unary, "Unable to create unary layer from node: " << *n); 
-      unary->setName((util::node_info(n)+ "_unary_sqrt").c_str());
+      TRTORCH_CHECK(unary, "Unable to create unary layer from node: " << *n);
+      unary->setName((util::node_info(n) + "_unary_sqrt").c_str());
      auto unary_out = unary->getOutput(0);

-
      // (x - E[x]) / sqrt((var + eps))
-      auto div= add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean, unary_out, (util::node_info(n)+"_div").c_str());
+      auto div = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean, unary_out, (util::node_info(n) + "_div").c_str());
      TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
-      div->setName((util::node_info(n)+ "_div").c_str());
+      div->setName((util::node_info(n) + "_div").c_str());
      auto div_out = div->getOutput(0);

-      // Set up gamma_weights and beta_weights from gamma_expand and beta_expand
+      // Set up gamma_weights and beta_weights from gamma_expand and
+      // beta_expand
      auto gamma_weights = Weights(ctx, gamma_expand);
      auto beta_weights = Weights(ctx, beta_expand);

      auto power = Weights(ctx, at::ones_like(gamma_expand));
-      auto scale_nd = ctx->net->addScaleNd(*div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
+      auto scale_nd = ctx->net->addScaleNd(
+          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);

-      scale_nd->setName((util::node_info(n)+ "_scale_nd").c_str());
+      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
      auto scale_nd_out = scale_nd->getOutput(0);

      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
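For reference, the chain of layers reformatted above (kSUB, kPOW, a kAVG reduce, kSUM with eps, kSQRT, kDIV, then addScaleNd with the beta/gamma weights) corresponds to the usual layer normalization expression; a minimal sketch of that math, assuming the standard definition:

y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta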