
Commit e30e5ac

chore: Update layer_norm converter to use INormalizationLayer (#2509)
1 parent 5de2524 commit e30e5ac
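
The gist of the change: rather than composing layer normalization from reduce, subtract, power, sqrt, and divide layers, the converter now emits TensorRT's fused INormalizationLayer (available since TensorRT 8.6). Below is a minimal sketch of that API, not part of the commit: the helper name layer_norm_sketch is hypothetical, and net, input, gamma, and beta are assumed to be constructed elsewhere.

#include <NvInfer.h>
#include <cstdint>

// Illustrative sketch only; the converter's actual wiring is in the diff below.
nvinfer1::ITensor* layer_norm_sketch(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* gamma,  // scale, already broadcast to the input rank
    nvinfer1::ITensor* beta,   // shift, already broadcast to the input rank
    uint32_t axes_mask,        // bitmask of the dimensions to normalize over
    float eps) {
  // One fused layer replaces the old reduce/sub/pow/sqrt/div chain.
  auto* layer = net->addNormalization(*input, *gamma, *beta, axes_mask);
  layer->setEpsilon(eps);
  layer->setComputePrecision(nvinfer1::DataType::kFLOAT);
  return layer->getOutput(0);
}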

File tree (2 files changed: +69 -129 lines):

  core/conversion/converters/impl/layer_norm.cpp
  tests/core/conversion/converters/test_layer_norm.cpp

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 64 additions & 118 deletions
@@ -10,138 +10,84 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* broadcast(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* to_broadcast,
+    const int nbDims,
+    const std::string& tag) {
+  auto to_broadcast_nbdims = to_broadcast->getDimensions().nbDims;
+  TORCHTRT_CHECK(to_broadcast_nbdims <= nbDims, "Cannot broadcast tensor with more dimensions than the target");
+  if (to_broadcast_nbdims == nbDims) {
+    return to_broadcast;
+  }
+  auto shape_layer = ctx->net->addShape(*to_broadcast);
+  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+  shape_layer->setName((util::node_info(n) + "_shape_" + tag).c_str());
+  auto shape_layer_out = shape_layer->getOutput(0);
+
+  auto extra_dims_tensor = torch::ones({nbDims - to_broadcast_nbdims}, torch::TensorOptions().dtype(torch::kInt32));
+  auto extra_dims_itensor = tensor_to_const(ctx, extra_dims_tensor);
+
+  std::vector<nvinfer1::ITensor*> to_concat = {extra_dims_itensor, shape_layer_out};
+  auto concat_layer = ctx->net->addConcatenation(to_concat.data(), to_concat.size());
+  TORCHTRT_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
+  concat_layer->setName((util::node_info(n) + "_concat_" + tag).c_str());
+  auto target_shape = concat_layer->getOutput(0);
+
+  auto shuffle_layer = ctx->net->addShuffle(*to_broadcast);
+  TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+  shuffle_layer->setName((util::node_info(n) + "_shuffle_" + tag).c_str());
+  shuffle_layer->setInput(1, *target_shape);
+  auto output = shuffle_layer->getOutput(0);
+  LOG_DEBUG(
+      "Broadcast " << tag << " to shape: " << output->getDimensions() << " from " << to_broadcast->getDimensions());
+  return output;
+}
+
 auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({
     R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
                            float eps, bool cudnn_enabled) -> (Tensor))SIG",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-      auto input = args[0].ITensor(); // assumes non-static input Tensor
-      auto orig_shape = input->getDimensions();
-      auto shape = util::toVec(orig_shape);
-
-      /* Layer_Norm normalizes over last N dimensions.
-         normalizaed_shape could be (C,H,W), (H,W), or (W). */
-      // This could be an IntList or ITensorList. We only need the size of this list.
-      auto normalized_shape = args[1].IValue()->toList();
-
-      // Unwrap eps.
-      auto eps = args[4].unwrapToDouble();
-
-      LOG_DEBUG("cudnn disregarded");
-
-      // Set up axis_ask for E[x].
-      uint32_t axis_mask = 0;
-      for (size_t i = 0; i < normalized_shape.size(); i++) {
-        axis_mask |= 1 << (shape.size() - i - 1);
+      auto input = args[0].ITensorOrFreeze(ctx);
+      auto input_shape = input->getDimensions();
+      auto input_shape_vec = util::toVec(input_shape);
+      auto normalized_shape = args[1].unwrapToIntList();
+      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
+      auto axis = input_shape_vec.size() - normalized_shape_vec.size();
+      uint32_t axes_mask = 0;
+      for (size_t i = axis; i < input_shape_vec.size(); i++) {
+        axes_mask |= 1 << i;
       }
-      LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));
-
-      // E[x]
-      auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
-      TORCHTRT_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n);
-      mean_expected->setName((util::node_info(n) + "_mean_expected").c_str());
-      auto mean_expected_out = mean_expected->getOutput(0);
-
-      // X-E[x]
-      auto sub = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str());
-      TORCHTRT_CHECK(sub, "Unable to create Sub layer from node: " << *n);
-      sub->setName((util::node_info(n) + "_sub").c_str());
-      auto xsubmean_out = sub->getOutput(0);
-
-      // Variance = mean(pow(xsubmean,2))
-      float pow_scalar = 2;
-      auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
-      auto pow = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str());
-      TORCHTRT_CHECK(pow, "Unable to create Pow layer from node: " << *n);
-      pow->setName((util::node_info(n) + "_pow").c_str());
-      auto pow_out = pow->getOutput(0);
-
-      auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
-      TORCHTRT_CHECK(mean_var, "Unable to create mean_var from node: " << *n);
-      mean_var->setName((util::node_info(n) + "_mean_var").c_str());
-      auto mean_var_out = mean_var->getOutput(0);
-
-      // Variance + eps
-      auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
-      auto add = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
-      TORCHTRT_CHECK(add, "Unable to create Add layer from node: " << *n);
-      add->setName((util::node_info(n) + "_add").c_str());
-      auto add_out = add->getOutput(0);
 
-      // SQRT((Var + eps))
-      auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
-      TORCHTRT_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n);
-      sqrt->setName((util::node_info(n) + "_sqrt").c_str());
-      auto sqrt_out = sqrt->getOutput(0);
-
-      // (x - E[x]) / sqrt((var + eps))
-      auto div = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str());
-      TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-      div->setName((util::node_info(n) + "_div").c_str());
-      auto div_out = div->getOutput(0);
-
-      if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) {
-        ctx->AssociateValueAndTensor(n->outputs()[0], div_out);
-        return true;
-      }
-
-      // Remove batch dimension from input shape for expand_size, which will
-      // be used to create weights for addScaleNd later.
-      auto expand_size = shape;
-      expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
-
-      // Set up gamma_weights and beta_weights from gamma_expand and
-      // beta_expand.
-      auto gamma_weights = Weights(ctx, at::ones(expand_size));
-      auto beta_weights = Weights(ctx, at::zeros(expand_size));
-
-      if (args[2].IValue()->isTensor()) {
-        torch::Tensor gamma;
-        gamma = args[2].unwrapToTensor();
-        auto gamma_expand = gamma.expand(expand_size);
-        gamma_weights = Weights(ctx, gamma_expand);
+      nvinfer1::ITensor* gamma = nullptr;
+      if (args[2].IValue()->isNone()) {
+        auto gamma_torch_tensor = torch::ones(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
+        gamma = tensor_to_const(ctx, gamma_torch_tensor);
       } else {
-        gamma_weights = Weights(ctx, at::ones(expand_size));
+        gamma = args[2].ITensorOrFreeze(ctx);
+        gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma");
       }
 
-      if (args[3].IValue()->isTensor()) {
-        torch::Tensor beta;
-        beta = args[3].unwrapToTensor();
-        auto beta_expand = beta.expand(expand_size);
-        beta_weights = Weights(ctx, beta_expand);
+      nvinfer1::ITensor* beta = nullptr;
+      if (args[3].IValue()->isNone()) {
+        auto beta_torch_tensor = torch::zeros(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
+        beta = tensor_to_const(ctx, beta_torch_tensor);
       } else {
-        beta_weights = Weights(ctx, at::zeros(expand_size));
+        beta = args[3].ITensorOrFreeze(ctx);
+        beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta");
       }
 
-      auto power = Weights(ctx, at::ones(expand_size));
-
-      auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
-      auto scale_l = add_elementwise(
-          ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
-
-      auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
-      auto shift_l = add_elementwise(
-          ctx,
-          nvinfer1::ElementWiseOperation::kSUM,
-          scale_l->getOutput(0),
-          beta_tensor,
-          (util::node_info(n) + "_shift").c_str());
-
-      auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
-      auto power_l = add_elementwise(
-          ctx,
-          nvinfer1::ElementWiseOperation::kPOW,
-          shift_l->getOutput(0),
-          power_tensor,
-          (util::node_info(n) + "_power").c_str());
+      auto eps = args[4].unwrapToDouble();
 
-      power_l->setName((util::node_info(n) + "_scale_nd").c_str());
-      auto power_l_out = power_l->getOutput(0);
+      auto normalize_layer = ctx->net->addNormalization(*input, *gamma, *beta, axes_mask);
+      TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n);
+      normalize_layer->setName(util::node_info(n).c_str());
+      normalize_layer->setEpsilon(eps);
+      normalize_layer->setComputePrecision(nvinfer1::DataType::kFLOAT);
+      auto normalized = normalize_layer->getOutput(0);
 
-      ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
+      ctx->AssociateValueAndTensor(n->outputs()[0], normalized);
       return true;
     }});

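For reference, a worked example (not part of the commit) of the axes-mask computation in the new converter: for a 4-D input [N, C, H, W] with normalized_shape = (C, H, W), axis = 4 - 3 = 1, so bits 1 through 3 are set and the mask comes out to 0b1110.

#include <cstdint>
#include <cstdio>

int main() {
  const size_t input_rank = 4;      // e.g. input shape [N, C, H, W]
  const size_t normalized_rank = 3; // normalized_shape = (C, H, W)
  uint32_t axes_mask = 0;
  // Same loop as the converter: set one bit per normalized dimension.
  for (size_t i = input_rank - normalized_rank; i < input_rank; i++) {
    axes_mask |= 1u << i;
  }
  std::printf("axes_mask = 0x%X\n", axes_mask); // prints 0xE (0b1110)
  return 0;
}
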
tests/core/conversion/converters/test_layer_norm.cpp

Lines changed: 5 additions & 11 deletions
@@ -29,8 +29,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3DimsNoGammaBeta) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
@@ -60,8 +59,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
@@ -90,8 +88,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast2Dims) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
@@ -119,8 +116,7 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
@@ -134,7 +130,6 @@ TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
         %8 : float = prim::Constant[value=1.0000000000000001e-05]()
         %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
         return (%9))IR";
-
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
@@ -148,6 +143,5 @@ TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
