Commit f04e3be

Support aten::layer_norm

Signed-off-by: Yu-Te Cheng <[email protected]>

1 parent 52947fe
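The converter added in this commit builds aten::layer_norm out of TensorRT reduce, elementwise, unary, and scale layers. The operation it reproduces is

    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \, \gamma + \beta

where E[x] and Var[x] are taken over the trailing dimensions given by normalized_shape, and gamma and beta are the optional affine parameters of the PyTorch op.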

File tree

4 files changed: +346 -0 lines changed

  core/conversion/converters/BUILD
  core/conversion/converters/impl/layer_norm.cpp
  tests/core/conversion/converters/BUILD
  tests/core/conversion/converters/test_layer_norm.cpp

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ cc_library(
         "impl/element_wise.cpp",
         "impl/expand.cpp",
         "impl/interpolate.cpp",
+        "impl/layer_norm.cpp",
         "impl/linear.cpp",
         "impl/lstm_cell.cpp",
         "impl/matrix_multiply.cpp",
core/conversion/converters/impl/layer_norm.cpp

Lines changed: 249 additions & 0 deletions

@@ -0,0 +1,249 @@
#include <algorithm>
#include <bitset>

#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

nvinfer1::ILayer* add_elementwise(
    ConversionCtx* ctx,
    nvinfer1::ElementWiseOperation op,
    nvinfer1::ITensor* self,
    nvinfer1::ITensor* other,
    const std::string& name) {
  // Ensure self has the larger number of dimensions
  bool swapSelfOther = false;
  if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
    std::swap(self, other);
    swapSelfOther = true;
  }
  auto selfDim = util::toVec(self->getDimensions());
  auto otherDim = util::toVec(other->getDimensions());
  if (selfDim.size() != otherDim.size()) {
    // other has a dynamic shape; expand its dimensions now and resolve its shape at runtime
    if (otherDim.end() != std::find(otherDim.begin(), otherDim.end(), -1)) {
      auto thOtherStaticShapeMask = torch::ones(selfDim.size(), torch::kInt32);
      auto thOtherDynamicShapeMask = torch::zeros(selfDim.size(), torch::kInt32);
      for (size_t start = selfDim.size() - otherDim.size(), idx = 0; idx < otherDim.size(); ++idx) {
        if (-1 != otherDim[idx]) {
          thOtherStaticShapeMask[start + idx] = otherDim[idx];
        } else {
          thOtherStaticShapeMask[start + idx] = 0;
          thOtherDynamicShapeMask[start + idx] = 1;
        }
      }
      auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
      auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
      auto selfShape = ctx->net->addShape(*self)->getOutput(0);
      // The size of each dynamic dimension of other needs to match the corresponding dimension of self
      auto otherDynamicShape =
          ctx->net->addElementWise(*selfShape, *otherDynamicShapeMask, nvinfer1::ElementWiseOperation::kPROD)
              ->getOutput(0);
      auto targetOtherShape =
          ctx->net->addElementWise(*otherDynamicShape, *otherStaticShapeMask, nvinfer1::ElementWiseOperation::kSUM)
              ->getOutput(0);

      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setName(std::string("Reshape other tensor to have the same nDim as self for " + name).c_str());
      otherShuffle->setInput(1, *targetOtherShape);
      other = otherShuffle->getOutput(0);
    } else {
      // other has a static shape; pad its dimensions so both tensors have the same number of dimensions
      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setReshapeDimensions(util::toDimsPad(otherDim, selfDim.size()));
      other = otherShuffle->getOutput(0);
    }
  }
  if (swapSelfOther) {
    // swap back
    std::swap(self, other);
    swapSelfOther = false;
  }
  auto ele = ctx->net->addElementWise(*self, *other, op);
  ele->setName(name.c_str());
  return ele;
}

auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
    R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
                           float eps, bool cudnn_enabled) -> (Tensor))SIG",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto input = args[0].ITensor(); // assumes non-static input Tensor
      auto orig_shape = input->getDimensions();
      auto shape = util::toVec(orig_shape);

      /* Layer norm normalizes over the last N dimensions.
         normalized_shape could be (C,H,W), (H,W), or (W). */
      auto normalized_shape = args[1].unwrapToIntList();
      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));

      torch::Tensor gamma, beta;
      gamma = args[2].unwrapToTensor();
      beta = args[3].unwrapToTensor();

      // Remove the batch dimension from the input shape for expand_size,
      // which is used to create the weights for addScaleNd later.
      auto expand_size = shape;
      expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
      auto gamma_expand = gamma.expand(expand_size);
      auto beta_expand = beta.expand(expand_size);

      // Unwrap eps.
      auto eps = args[4].unwrapToDouble();
      LOG_DEBUG("cudnn disregarded");

      // Set up axis_mask for E[x]: one bit per normalized (trailing) dimension.
      uint32_t axis_mask = 0;
      for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
        axis_mask |= 1 << (shape.size() - i - 1);
      }
      LOG_DEBUG("Axis Mask for E[x]: " << std::bitset<32>(axis_mask));

      // E[x]
      auto mean_layer_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
      TRTORCH_CHECK(mean_layer_expected, "Unable to create mean_layer_expected from node: " << *n);
      mean_layer_expected->setName((util::node_info(n) + "_mean_expected").c_str());
      auto mean_layer_expected_out = mean_layer_expected->getOutput(0);

      // Expand the output of E[x] to the same shape as the original input.
      c10::List<int64_t> repeats_expected;
      for (size_t i = 0; i < shape.size(); i++) {
        auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
        repeats_expected.push_back(repeat);
      }

      int repeats_expected_rank = repeats_expected.size();
      auto mean_layer_expected_out_dims = mean_layer_expected_out->getDimensions();
      auto num_expand_dims_expected = repeats_expected_rank - mean_layer_expected_out_dims.nbDims;

      if (num_expand_dims_expected > 0) {
        nvinfer1::Dims reshape_expected_dims;
        reshape_expected_dims.nbDims = repeats_expected.size();
        for (int i = 0; i < num_expand_dims_expected; i++) {
          reshape_expected_dims.d[repeats_expected.size() - 1 - i] = 1;
        }
        for (int i = 0; i < mean_layer_expected_out_dims.nbDims; i++) {
          reshape_expected_dims.d[i] = mean_layer_expected_out_dims.d[i];
        }
        // Add a reshape layer to expand dims
        auto reshape_layer_expected = ctx->net->addShuffle(*mean_layer_expected_out);
        reshape_layer_expected->setReshapeDimensions(reshape_expected_dims);
        mean_layer_expected_out = reshape_layer_expected->getOutput(0);
      }

      // Tile the reduced mean back to the input shape by concatenating it along each repeated axis.
      for (int i = repeats_expected.size() - 1; i >= 0; --i) {
        std::vector<nvinfer1::ITensor*> tensors_vec;
        for (int j = 0; j < repeats_expected[i]; j++) {
          tensors_vec.push_back(mean_layer_expected_out);
        }
        auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
        concat_layer->setAxis(i);
        mean_layer_expected_out = concat_layer->getOutput(0);
      }

      // x - E[x]
      auto sub = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_layer_expected_out, (util::node_info(n) + "_sub").c_str());
      TRTORCH_CHECK(sub, "Unable to create Sub layer from node: " << *n);
      sub->setName((util::node_info(n) + "_sub").c_str());
      auto xsubmean = sub->getOutput(0);

      // Variance = E[(x - E[x])^2]
      float pow_scalar = 2;
      auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
      auto pow = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean, exponent, (util::node_info(n) + "_pow").c_str());
      TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
      pow->setName((util::node_info(n) + "_pow").c_str());
      auto pow_out = pow->getOutput(0);

      auto mean_layer_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
      TRTORCH_CHECK(mean_layer_var, "Unable to create mean_layer_var from node: " << *n);
      mean_layer_var->setName((util::node_info(n) + "_mean_var").c_str());
      auto mean_layer_var_out = mean_layer_var->getOutput(0);

      // Expand the output of mean_layer_var to the same shape as the original input.
      c10::List<int64_t> repeats_var;
      for (size_t i = 0; i < shape.size(); i++) {
        auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
        repeats_var.push_back(repeat);
      }

      int repeats_var_rank = repeats_var.size();
      auto mean_layer_var_out_dims = mean_layer_var_out->getDimensions();
      auto num_expand_dims_var = repeats_var_rank - mean_layer_var_out_dims.nbDims;

      if (num_expand_dims_var > 0) {
        nvinfer1::Dims reshape_dims_var;
        reshape_dims_var.nbDims = repeats_var.size();
        for (int i = 0; i < num_expand_dims_var; i++) {
          reshape_dims_var.d[repeats_var.size() - 1 - i] = 1;
        }
        for (int i = 0; i < mean_layer_var_out_dims.nbDims; i++) {
          reshape_dims_var.d[i] = mean_layer_var_out_dims.d[i];
        }

        // Add a reshape layer to expand dims
        auto reshape_layer_var = ctx->net->addShuffle(*mean_layer_var_out);
        reshape_layer_var->setReshapeDimensions(reshape_dims_var);
        mean_layer_var_out = reshape_layer_var->getOutput(0);
      }

      // Tile the reduced variance back to the input shape, mirroring the expansion of E[x] above.
      for (int i = repeats_var.size() - 1; i >= 0; --i) {
        std::vector<nvinfer1::ITensor*> tensors_vec;
        for (int j = 0; j < repeats_var[i]; j++) {
          tensors_vec.push_back(mean_layer_var_out);
        }
        auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
        concat_layer->setAxis(i);
        mean_layer_var_out = concat_layer->getOutput(0);
      }

      // var + eps
      auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
      auto add = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kSUM, mean_layer_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
      TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n);
      add->setName((util::node_info(n) + "_add").c_str());
      auto add_out = add->getOutput(0);

      // Add a Unary layer for sqrt(var + eps)
      auto unary = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
      TRTORCH_CHECK(unary, "Unable to create Unary layer from node: " << *n);
      unary->setName((util::node_info(n) + "_unary_sqrt").c_str());
      auto unary_out = unary->getOutput(0);

      // (x - E[x]) / sqrt(var + eps)
      auto div = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean, unary_out, (util::node_info(n) + "_div").c_str());
      TRTORCH_CHECK(div, "Unable to create Div layer from node: " << *n);
      div->setName((util::node_info(n) + "_div").c_str());
      auto div_out = div->getOutput(0);

      // Set up gamma_weights and beta_weights from gamma_expand and beta_expand
      auto gamma_weights = Weights(ctx, gamma_expand);
      auto beta_weights = Weights(ctx, beta_expand);

      // y = gamma * x_normalized + beta, applied elementwise over the non-batch dimensions
      auto power = Weights(ctx, at::ones_like(gamma_expand));
      auto scale_nd = ctx->net->addScaleNd(
          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);

      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
      auto scale_nd_out = scale_nd->getOutput(0);

      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
      return true;
    }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
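As a point of reference, the decomposition the converter emits (mean, centering, variance, sqrt(var + eps), division, then the gamma/beta affine step) can be mirrored in a few lines of libtorch and compared against at::layer_norm. The sketch below is illustrative only and is not part of this commit; the shapes are arbitrary.

#include <iostream>
#include "torch/torch.h"

// Minimal sanity check (not part of this commit): reproduce the converter's
// layer_norm decomposition with libtorch ops and compare to at::layer_norm.
int main() {
  auto x = torch::randn({2, 3, 8, 8});
  auto gamma = torch::randn({3, 8, 8});
  auto beta = torch::randn({3, 8, 8});
  double eps = 1e-5;

  // Reference result from the ATen op being converted.
  auto ref = torch::layer_norm(x, {3, 8, 8}, gamma, beta, eps);

  // Decomposition mirroring the converter: E[x] and Var[x] over the
  // normalized (trailing) dimensions, then normalize, then scale and shift.
  auto mean = x.mean({1, 2, 3}, /*keepdim=*/true);
  auto centered = x - mean;
  auto var = centered.pow(2).mean({1, 2, 3}, /*keepdim=*/true);
  auto normalized = centered / torch::sqrt(var + eps);
  auto out = normalized * gamma + beta;

  std::cout << "matches aten::layer_norm: " << torch::allclose(ref, out, 1e-4, 1e-5) << std::endl;
  return 0;
}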

tests/core/conversion/converters/BUILD

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ converter_test(
     name = "test_expand",
 )

+converter_test(
+    name = "test_layer_norm",
+)
+
 converter_test(
     name = "test_linear",
 )
tests/core/conversion/converters/test_layer_norm.cpp

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %gamma: Float(3, 100, 100),
          %beta: Float(3, 100, 100)):
      %1: int = prim::Constant[value=3]()
      %2: int = prim::Constant[value=100]()
      %3: int = prim::Constant[value=100]()
      %4 : int[] = prim::ListConstruct(%1, %2, %3)
      %7 : bool = prim::Constant[value=0]()
      %8 : float = prim::Constant[value=1.0000000000000001e-05]()
      %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
      return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {4, 3, 100, 100}, {at::kCUDA});
  auto gamma = at::randint(1, 10, {3, 100, 100}, {at::kCUDA});
  auto beta = at::randint(1, 10, {3, 100, 100}, {at::kCUDA});

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %gamma : Float(100, 100),
          %beta : Float(100, 100)):
      %2: int = prim::Constant[value=100]()
      %3: int = prim::Constant[value=100]()
      %4 : int[] = prim::ListConstruct(%2, %3)
      %7 : bool = prim::Constant[value=0]()
      %8 : float = prim::Constant[value=1.0000000000000001e-05]()
      %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
      return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {4, 3, 100, 100}, {at::kCUDA});
  auto gamma = at::randint(1, 10, {100, 100}, {at::kCUDA});
  auto beta = at::randint(1, 10, {100, 100}, {at::kCUDA});

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %gamma: Float(100),
          %beta: Float(100)):
      %3: int = prim::Constant[value=100]()
      %4 : int[] = prim::ListConstruct(%3)
      %7 : bool = prim::Constant[value=0]()
      %8 : float = prim::Constant[value=1.0000000000000001e-05]()
      %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
      return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {4, 3, 100, 100}, {at::kCUDA});
  auto gamma = at::randint(1, 10, {100}, {at::kCUDA});
  auto beta = at::randint(1, 10, {100}, {at::kCUDA});

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
