@@ -81,23 +81,12 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().

/* Layer_Norm normalizes over last N dimensions.
   normalizaed_shape could be (C,H,W), (H,W), or (W). */
-
auto normalized_shape = args[1].unwrapToIntList();
auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));

- torch::Tensor gamma, beta;
- gamma = args[2].unwrapToTensor();
- beta = args[3].unwrapToTensor();
-
- // Remove batch dimension from input shape for expand_size, which will
- // be used to create weights for addScaleNd later.
- auto expand_size = shape;
- expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
- auto gamma_expand = gamma.expand(expand_size);
- auto beta_expand = beta.expand(expand_size);
-
// Unwrap eps.
auto eps = args[4].unwrapToDouble();
+
LOG_DEBUG("cudnn disregarded");

// Set up axis_ask for E[x].
@@ -108,144 +97,89 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));

// E[x]
- auto mean_layer_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
- TRTORCH_CHECK(mean_layer_expected, "Unable to create mean_layer_expected from node: " << *n);
- mean_layer_expected->setName((util::node_info(n) + "_mean_expected").c_str());
- auto mean_layer_expected_out = mean_layer_expected->getOutput(0);
-
- // Expand output of E[x] to the same shape as original input.
- c10::List<int64_t> repeats_expected;
- for (size_t i = 0; i < shape.size(); i++) {
-   auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
-   repeats_expected.push_back(repeat);
- }
-
- int repeats_expected_rank = repeats_expected.size();
- auto mean_layer_expected_out_dims = mean_layer_expected_out->getDimensions();
- auto num_expand_dims_expected = repeats_expected_rank - mean_layer_expected_out_dims.nbDims;
-
- if (num_expand_dims_expected > 0) {
-   nvinfer1::Dims reshape_expected_dims;
-   reshape_expected_dims.nbDims = repeats_expected.size();
-   for (int i = 0; i < num_expand_dims_expected; i++) {
-     reshape_expected_dims.d[repeats_expected.size() - 1 - i] = 1;
-   }
-   for (int i = 0; i < mean_layer_expected_out_dims.nbDims; i++) {
-     reshape_expected_dims.d[i] = mean_layer_expected_out_dims.d[i];
-   }
-   // Add a reshape layer to expand dims
-   auto reshape_layer_expected = ctx->net->addShuffle(*mean_layer_expected_out);
-   reshape_layer_expected->setReshapeDimensions(reshape_expected_dims);
-   mean_layer_expected_out = reshape_layer_expected->getOutput(0);
- }
-
- for (int i = repeats_expected.size() - 1; i >= 0; --i) {
-   std::vector<nvinfer1::ITensor*> tensors_vec;
-   for (int j = 0; j < repeats_expected[i]; j++) {
-     tensors_vec.push_back(mean_layer_expected_out);
-   }
-   auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-   concat_layer->setAxis(i);
-   mean_layer_expected_out = concat_layer->getOutput(0);
- }
+ auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
+ TRTORCH_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n);
+ mean_expected->setName((util::node_info(n) + "_mean_expected").c_str());
+ auto mean_expected_out = mean_expected->getOutput(0);

// X-E[x]
auto sub = add_elementwise(
-     ctx,
-     nvinfer1::ElementWiseOperation::kSUB,
-     input,
-     mean_layer_expected_out,
-     (util::node_info(n) + "_sub").c_str());
- TRTORCH_CHECK(sub, "Unable to create Add layer from node: " << *n);
+     ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str());
+ TRTORCH_CHECK(sub, "Unable to create Sub layer from node: " << *n);
sub->setName((util::node_info(n) + "_sub").c_str());
- auto xsubmean = sub->getOutput(0);
+ auto xsubmean_out = sub->getOutput(0);

- // Variance
+ // Variance = mean(pow(xsubmean,2))
float pow_scalar = 2;
auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
auto pow = add_elementwise(
-     ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean, exponent, (util::node_info(n) + "_pow").c_str());
- TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
+     ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str());
+ TRTORCH_CHECK(pow, "Unable to create Pow layer from node: " << *n);
pow->setName((util::node_info(n) + "_pow").c_str());
auto pow_out = pow->getOutput(0);

- auto mean_layer_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
- TRTORCH_CHECK(mean_layer_var, "Unable to create mean_layer_var from node: " << *n);
- mean_layer_var->setName((util::node_info(n) + "_mean_var").c_str());
- auto mean_layer_var_out = mean_layer_var->getOutput(0);
-
- // Expand output of mean_layer_var to the same shape as original
- // input.
- c10::List<int64_t> repeats_var;
- for (size_t i = 0; i < shape.size(); i++) {
-   auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
-   repeats_var.push_back(repeat);
- }
-
- int repeats_var_rank = repeats_var.size();
- auto mean_layer_var_out_dims = mean_layer_var_out->getDimensions();
- auto num_expand_dims_var = repeats_var_rank - mean_layer_var_out_dims.nbDims;
-
- if (num_expand_dims_var > 0) {
-   nvinfer1::Dims reshape_dims_var;
-   reshape_dims_var.nbDims = repeats_var.size();
-   for (int i = 0; i < num_expand_dims_var; i++) {
-     reshape_dims_var.d[repeats_var.size() - 1 - i] = 1;
-   }
-   for (int i = 0; i < mean_layer_var_out_dims.nbDims; i++) {
-     reshape_dims_var.d[i] = mean_layer_var_out_dims.d[i];
-   }
-
-   // Add a reshape layer to expand dims
-   auto reshape_layer_var = ctx->net->addShuffle(*mean_layer_var_out);
-   reshape_layer_var->setReshapeDimensions(reshape_dims_var);
-   mean_layer_var_out = reshape_layer_var->getOutput(0);
- }
+ auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
+ TRTORCH_CHECK(mean_var, "Unable to create mean_var from node: " << *n);
+ mean_var->setName((util::node_info(n) + "_mean_var").c_str());
+ auto mean_var_out = mean_var->getOutput(0);

- for (int i = repeats_var.size() - 1; i >= 0; --i) {
-   std::vector<nvinfer1::ITensor*> tensors_vec;
-   for (int j = 0; j < repeats_var[i]; j++) {
-     tensors_vec.push_back(mean_layer_var_out);
-   }
-   auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-   concat_layer->setAxis(i);
-   mean_layer_var_out = concat_layer->getOutput(0);
- }
-
- // add eps
+ // Variance + eps
auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
auto add = add_elementwise(
-     ctx,
-     nvinfer1::ElementWiseOperation::kSUM,
-     mean_layer_var_out,
-     eps_tensor,
-     (util::node_info(n) + "_add").c_str());
+     ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n);
add->setName((util::node_info(n) + "_add").c_str());
auto add_out = add->getOutput(0);

- // add Unary layer for sqrt((var + eps))
- auto unary = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
- TRTORCH_CHECK(unary, "Unable to create unary layer from node: " << *n);
- unary->setName((util::node_info(n) + "_unary_sqrt").c_str());
- auto unary_out = unary->getOutput(0);
+ // SQRT((Var + eps))
+ auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
+ TRTORCH_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n);
+ sqrt->setName((util::node_info(n) + "_sqrt").c_str());
+ auto sqrt_out = sqrt->getOutput(0);

// (x - E[x]) / sqrt((var + eps))
auto div = add_elementwise(
-     ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean, unary_out, (util::node_info(n) + "_div").c_str());
+     ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str());
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
div->setName((util::node_info(n) + "_div").c_str());
auto div_out = div->getOutput(0);

+ if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) {
+   ctx->AssociateValueAndTensor(n->outputs()[0], div_out);
+   return true;
+ }
+
+ // Remove batch dimension from input shape for expand_size, which will
+ // be used to create weights for addScaleNd later.
+ auto expand_size = shape;
+ expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
+
// Set up gamma_weights and beta_weights from gamma_expand and
- // beta_expand
- auto gamma_weights = Weights(ctx, gamma_expand);
- auto beta_weights = Weights(ctx, beta_expand);
+ // beta_expand.
+ auto gamma_weights = Weights(ctx, at::ones(expand_size));
+ auto beta_weights = Weights(ctx, at::zeros(expand_size));
+
+ if (args[2].IValue()->isTensor()) {
+   torch::Tensor gamma;
+   gamma = args[2].unwrapToTensor();
+   auto gamma_expand = gamma.expand(expand_size);
+   gamma_weights = Weights(ctx, gamma_expand);
+ } else {
+   gamma_weights = Weights(ctx, at::ones(expand_size));
+ }

- auto power = Weights(ctx, at::ones_like(gamma_expand));
+ if (args[3].IValue()->isTensor()) {
+   torch::Tensor beta;
+   beta = args[3].unwrapToTensor();
+   auto beta_expand = beta.expand(expand_size);
+   beta_weights = Weights(ctx, beta_expand);
+ } else {
+   beta_weights = Weights(ctx, at::zeros(expand_size));
+ }
+
+ auto power = Weights(ctx, at::ones(expand_size));
auto scale_nd = ctx->net->addScaleNd(
    *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-
scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
auto scale_nd_out = scale_nd->getOutput(0);

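In effect, the rewritten converter evaluates layer normalization with native TensorRT reduce, elementwise, unary, and scale layers (no plugin), computing roughly the textbook expression over the trailing normalized_shape dimensions:

y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta,  with Var[x] = E[(x - E[x])^2]

where gamma falls back to ones and beta to zeros when aten::layer_norm is called without weight/bias tensors, and the result is returned directly from div_out when both are absent.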