
Commit 1b554a1

Remove redundant code in layer_norm and generalize layer_norm to accept gamma and beta as None.
Signed-off-by: Yu-Te Cheng <[email protected]>
1 parent b3be1a3 commit 1b554a1
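
For reference, the converter this commit touches builds layer normalization out of TensorRT primitives (a kAVG reduce for each mean, elementwise ops for the rest). A sketch of the math, with the affine step now optional because gamma and beta may be None:

\[
\mathrm{LN}(x) = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\,\gamma + \beta,
\qquad
\mathrm{Var}[x] = \mathrm{E}\big[(x - \mathrm{E}[x])^2\big]
\]

where the expectations run over the last N dimensions given by normalized_shape. When gamma and beta are None (the new code path), the converter returns the normalized tensor directly, which is equivalent to gamma = 1, beta = 0.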

3 files changed: +87 -122 lines changed

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 55 additions & 121 deletions
@@ -81,23 +81,12 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
 
   /* Layer_Norm normalizes over last N dimensions.
      normalized_shape could be (C,H,W), (H,W), or (W). */
-
   auto normalized_shape = args[1].unwrapToIntList();
   auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
 
-  torch::Tensor gamma, beta;
-  gamma = args[2].unwrapToTensor();
-  beta = args[3].unwrapToTensor();
-
-  // Remove batch dimension from input shape for expand_size, which will
-  // be used to create weights for addScaleNd later.
-  auto expand_size = shape;
-  expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
-  auto gamma_expand = gamma.expand(expand_size);
-  auto beta_expand = beta.expand(expand_size);
-
   // Unwrap eps.
   auto eps = args[4].unwrapToDouble();
+
   LOG_DEBUG("cudnn disregarded");
 
   // Set up axis_mask for E[x].
@@ -108,144 +97,89 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
   LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));
 
   // E[x]
-  auto mean_layer_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
-  TRTORCH_CHECK(mean_layer_expected, "Unable to create mean_layer_expected from node: " << *n);
-  mean_layer_expected->setName((util::node_info(n) + "_mean_expected").c_str());
-  auto mean_layer_expected_out = mean_layer_expected->getOutput(0);
-
-  // Expand output of E[x] to the same shape as original input.
-  c10::List<int64_t> repeats_expected;
-  for (size_t i = 0; i < shape.size(); i++) {
-    auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
-    repeats_expected.push_back(repeat);
-  }
-
-  int repeats_expected_rank = repeats_expected.size();
-  auto mean_layer_expected_out_dims = mean_layer_expected_out->getDimensions();
-  auto num_expand_dims_expected = repeats_expected_rank - mean_layer_expected_out_dims.nbDims;
-
-  if (num_expand_dims_expected > 0) {
-    nvinfer1::Dims reshape_expected_dims;
-    reshape_expected_dims.nbDims = repeats_expected.size();
-    for (int i = 0; i < num_expand_dims_expected; i++) {
-      reshape_expected_dims.d[repeats_expected.size() - 1 - i] = 1;
-    }
-    for (int i = 0; i < mean_layer_expected_out_dims.nbDims; i++) {
-      reshape_expected_dims.d[i] = mean_layer_expected_out_dims.d[i];
-    }
-    // Add a reshape layer to expand dims
-    auto reshape_layer_expected = ctx->net->addShuffle(*mean_layer_expected_out);
-    reshape_layer_expected->setReshapeDimensions(reshape_expected_dims);
-    mean_layer_expected_out = reshape_layer_expected->getOutput(0);
-  }
-
-  for (int i = repeats_expected.size() - 1; i >= 0; --i) {
-    std::vector<nvinfer1::ITensor*> tensors_vec;
-    for (int j = 0; j < repeats_expected[i]; j++) {
-      tensors_vec.push_back(mean_layer_expected_out);
-    }
-    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-    concat_layer->setAxis(i);
-    mean_layer_expected_out = concat_layer->getOutput(0);
-  }
+  auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
+  TRTORCH_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n);
+  mean_expected->setName((util::node_info(n) + "_mean_expected").c_str());
+  auto mean_expected_out = mean_expected->getOutput(0);
 
   // X-E[x]
   auto sub = add_elementwise(
-      ctx,
-      nvinfer1::ElementWiseOperation::kSUB,
-      input,
-      mean_layer_expected_out,
-      (util::node_info(n) + "_sub").c_str());
-  TRTORCH_CHECK(sub, "Unable to create Add layer from node: " << *n);
+      ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str());
+  TRTORCH_CHECK(sub, "Unable to create Sub layer from node: " << *n);
   sub->setName((util::node_info(n) + "_sub").c_str());
-  auto xsubmean = sub->getOutput(0);
+  auto xsubmean_out = sub->getOutput(0);
 
-  // Variance
+  // Variance = mean(pow(xsubmean,2))
   float pow_scalar = 2;
   auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
   auto pow = add_elementwise(
-      ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean, exponent, (util::node_info(n) + "_pow").c_str());
-  TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
+      ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str());
+  TRTORCH_CHECK(pow, "Unable to create Pow layer from node: " << *n);
   pow->setName((util::node_info(n) + "_pow").c_str());
   auto pow_out = pow->getOutput(0);
 
-  auto mean_layer_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
-  TRTORCH_CHECK(mean_layer_var, "Unable to create mean_layer_var from node: " << *n);
-  mean_layer_var->setName((util::node_info(n) + "_mean_var").c_str());
-  auto mean_layer_var_out = mean_layer_var->getOutput(0);
-
-  // Expand output of mean_layer_var to the same shape as original
-  // input.
-  c10::List<int64_t> repeats_var;
-  for (size_t i = 0; i < shape.size(); i++) {
-    auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
-    repeats_var.push_back(repeat);
-  }
-
-  int repeats_var_rank = repeats_var.size();
-  auto mean_layer_var_out_dims = mean_layer_var_out->getDimensions();
-  auto num_expand_dims_var = repeats_var_rank - mean_layer_var_out_dims.nbDims;
-
-  if (num_expand_dims_var > 0) {
-    nvinfer1::Dims reshape_dims_var;
-    reshape_dims_var.nbDims = repeats_var.size();
-    for (int i = 0; i < num_expand_dims_var; i++) {
-      reshape_dims_var.d[repeats_var.size() - 1 - i] = 1;
-    }
-    for (int i = 0; i < mean_layer_var_out_dims.nbDims; i++) {
-      reshape_dims_var.d[i] = mean_layer_var_out_dims.d[i];
-    }
-
-    // Add a reshape layer to expand dims
-    auto reshape_layer_var = ctx->net->addShuffle(*mean_layer_var_out);
-    reshape_layer_var->setReshapeDimensions(reshape_dims_var);
-    mean_layer_var_out = reshape_layer_var->getOutput(0);
-  }
+  auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
+  TRTORCH_CHECK(mean_var, "Unable to create mean_var from node: " << *n);
+  mean_var->setName((util::node_info(n) + "_mean_var").c_str());
+  auto mean_var_out = mean_var->getOutput(0);
 
-  for (int i = repeats_var.size() - 1; i >= 0; --i) {
-    std::vector<nvinfer1::ITensor*> tensors_vec;
-    for (int j = 0; j < repeats_var[i]; j++) {
-      tensors_vec.push_back(mean_layer_var_out);
-    }
-    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-    concat_layer->setAxis(i);
-    mean_layer_var_out = concat_layer->getOutput(0);
-  }
-
-  // add eps
+  // Variance + eps
   auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
   auto add = add_elementwise(
-      ctx,
-      nvinfer1::ElementWiseOperation::kSUM,
-      mean_layer_var_out,
-      eps_tensor,
-      (util::node_info(n) + "_add").c_str());
+      ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
   TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n);
   add->setName((util::node_info(n) + "_add").c_str());
   auto add_out = add->getOutput(0);
 
-  // add Unary layer for sqrt((var + eps))
-  auto unary = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
-  TRTORCH_CHECK(unary, "Unable to create unary layer from node: " << *n);
-  unary->setName((util::node_info(n) + "_unary_sqrt").c_str());
-  auto unary_out = unary->getOutput(0);
+  // SQRT((Var + eps))
+  auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
+  TRTORCH_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n);
+  sqrt->setName((util::node_info(n) + "_sqrt").c_str());
+  auto sqrt_out = sqrt->getOutput(0);
 
   // (x - E[x]) / sqrt((var + eps))
   auto div = add_elementwise(
-      ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean, unary_out, (util::node_info(n) + "_div").c_str());
+      ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str());
   TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
   div->setName((util::node_info(n) + "_div").c_str());
   auto div_out = div->getOutput(0);
 
+  if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) {
+    ctx->AssociateValueAndTensor(n->outputs()[0], div_out);
+    return true;
+  }
+
+  // Remove batch dimension from input shape for expand_size, which will
+  // be used to create weights for addScaleNd later.
+  auto expand_size = shape;
+  expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
+
   // Set up gamma_weights and beta_weights from gamma_expand and
-  // beta_expand
-  auto gamma_weights = Weights(ctx, gamma_expand);
-  auto beta_weights = Weights(ctx, beta_expand);
+  // beta_expand.
+  auto gamma_weights = Weights(ctx, at::ones(expand_size));
+  auto beta_weights = Weights(ctx, at::zeros(expand_size));
+
+  if (args[2].IValue()->isTensor()) {
+    torch::Tensor gamma;
+    gamma = args[2].unwrapToTensor();
+    auto gamma_expand = gamma.expand(expand_size);
+    gamma_weights = Weights(ctx, gamma_expand);
+  } else {
+    gamma_weights = Weights(ctx, at::ones(expand_size));
+  }
 
-  auto power = Weights(ctx, at::ones_like(gamma_expand));
+  if (args[3].IValue()->isTensor()) {
+    torch::Tensor beta;
+    beta = args[3].unwrapToTensor();
+    auto beta_expand = beta.expand(expand_size);
+    beta_weights = Weights(ctx, beta_expand);
+  } else {
+    beta_weights = Weights(ctx, at::zeros(expand_size));
+  }
+
+  auto power = Weights(ctx, at::ones(expand_size));
   auto scale_nd = ctx->net->addScaleNd(
       *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-
   scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
   auto scale_nd_out = scale_nd->getOutput(0);
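
As an aside, the rewritten converter leans on broadcasting: both addReduce calls now pass keepDimensions=true, so E[x] and Var[x] keep the input's rank and broadcast directly against it, which is what made the deleted shuffle-and-concat expansion code unnecessary. A minimal standalone libtorch sketch of the same arithmetic (not part of the commit; the comparison against torch::layer_norm and all names here are illustrative):

#include <torch/torch.h>
#include <iostream>

// Sketch only: mirrors the converter's math (kAVG reduce with keepdim,
// then kSUB/kPOW/kSUM/kSQRT/kDIV elementwise ops) with plain tensor ops.
int main() {
  auto x = torch::randn({4, 3, 100, 100});
  double eps = 1e-5;

  // Reduce over the normalized dims (C,H,W here) with keepdim=true, so the
  // statistics broadcast against x -- no explicit tiling needed.
  auto mean = x.mean({-3, -2, -1}, /*keepdim=*/true);                    // E[x]
  auto var = (x - mean).pow(2).mean({-3, -2, -1}, /*keepdim=*/true);     // Var[x]
  auto y = (x - mean) / torch::sqrt(var + eps);  // gamma/beta-as-None path

  // Should match aten::layer_norm with no affine parameters.
  auto ref = torch::layer_norm(x, {3, 100, 100});
  std::cout << std::boolalpha << torch::allclose(y, ref, 1e-5, 2e-6) << '\n';
  return 0;
}

The early return in the diff corresponds to stopping at y above; only when gamma or beta is an actual tensor does the converter go on to emit the addScaleNd affine step.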

tests/core/conversion/converters/BUILD

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,7 @@ converter_test(
 )
 
 converter_test(
-name = "test_layer_norm",
+    name = "test_layer_norm",
 )
 
 converter_test(
@@ -114,6 +114,7 @@ test_suite(
         ":test_element_wise",
         ":test_expand",
         ":test_interpolate",
+        ":test_layer_norm",
        ":test_linear",
        ":test_lstm_cell",
        ":test_matrix_multiply",

tests/core/conversion/converters/test_layer_norm.cpp

Lines changed: 30 additions & 0 deletions
@@ -4,6 +4,36 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 
+TEST(Converters, ATenLayerNormConvertsCorrectlyLast3DimsNoGammaBeta) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %gamma : None = prim::Constant()
+      %beta : None = prim::Constant()
+      %1: int = prim::Constant[value=3]()
+      %2: int = prim::Constant[value=100]()
+      %3: int = prim::Constant[value=100]()
+      %4 : int[] = prim::ListConstruct(%1, %2, %3)
+      %7 : bool = prim::Constant[value=0]()
+      %8 : float = prim::Constant[value=1.0000000000000001e-05]()
+      %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
+      return (%9))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {4, 3, 100, 100}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+
+}
+
+
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
   const auto graph = R"IR(
     graph(%0 : Tensor,
