Skip to content

Commit 5ec6d8c

Browse files
committed
Add ONNX opset 23 RMSNormalization operator support
Implements the RMSNormalization operator for the TensorRT ONNX parser, enabling deployment of modern transformer architectures (LLaMA, Mistral, etc.) that use RMSNorm instead of LayerNorm.

Implementation details:
- Computes Y = (X / sqrt(mean(X^2) + epsilon)) * scale
- Supports FP32, FP16, and BF16 data types
- Handles the axis attribute for normalization dimensions
- Supports the epsilon and stash_type attributes per the ONNX spec

Changes:
- onnxOpImporters.cpp: Add RMSNormalization importer using TensorRT primitive operations (ElementWise, Reduce, Unary)
- onnxOpCheckers.cpp: Add empty checker for RMSNormalization
- docs/operators.md: Add RMSNormalization to the supported-operators matrix
- onnx_backend_test.py: Include RMSNormalization tests

Fixes onnx/onnx-tensorrt#4639 (via NVIDIA/TensorRT#4639)

Signed-off-by: Aditi_Pandey <[email protected]>
1 parent c727277 commit 5ec6d8c

File tree

4 files changed

+90
-0
lines changed

4 files changed

+90
-0
lines changed

docs/operators.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ TensorRT supports the following ONNX data types: DOUBLE, FLOAT32, FLOAT16, BFLOA
161161
| Reshape | Y | FP32, FP16, BF16, INT32, INT64, BOOL |
162162
| Resize | Y | FP32, FP16, BF16 | Supported resize transformation modes: `half_pixel`, `pytorch_half_pixel`, `tf_half_pixel_for_nn`, `asymmetric`, and `align_corners`.<br />Supported resize modes: `nearest`, `linear`.<br />Supported nearest modes: `floor`, `ceil`, `round_prefer_floor`, `round_prefer_ceil`.<br />Supported aspect ratio policy: `stretch`.<br />When `scales` is a tensor input, `axes` must be an iota vector of length rank(input).<br />Antialiasing is not supported.|
163163
| ReverseSequence | Y | FP32, FP16, BF16, INT32, INT64, BOOL |
164+
| RMSNormalization | Y | FP32, FP16, BF16 | Only the first output `Y` is supported. Introduced in opset 23.
164165
| RNN | Y | FP32, FP16, BF16| For bidirectional RNNs, activation functions must be the same for both the forward and reverse pass
165166
| RoiAlign | Y | FP32, FP16 |
166167
| Round | Y | FP32, FP16, BF16 |

onnxOpCheckers.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,8 @@ DEFINE_OP_EMPTY_CHECKER(ReduceSum)
599599

600600
DEFINE_OP_EMPTY_CHECKER(ReduceSumSquare)
601601

602+
DEFINE_OP_EMPTY_CHECKER(RMSNormalization)
603+
602604
DEFINE_OP_EMPTY_CHECKER(Relu)
603605

604606
DEFINE_OP_EMPTY_CHECKER(Sign)

onnxOpImporters.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4791,6 +4791,92 @@ DEFINE_BUILTIN_OP_IMPORTER(ReduceSumSquare)
47914791
inputs.size() >= 2 ? inputs.at(1) : TensorOrWeights());
47924792
}
47934793

4794+
// RMSNormalization: Y = (X / sqrt(mean(X^2) + epsilon)) * scale
4795+
// Introduced in ONNX opset 23
4796+
DEFINE_BUILTIN_OP_IMPORTER(RMSNormalization)
4797+
{
4798+
using eOp = nvinfer1::ElementWiseOperation;
4799+
using uOp = nvinfer1::UnaryOperation;
4800+
using rOp = nvinfer1::ReduceOperation;
4801+
4802+
// Get input tensor
4803+
nvinfer1::ITensor* input = &convertToTensor(inputs.at(0), ctx);
4804+
auto const nbDims = input->getDimensions().nbDims;
4805+
auto const dt = input->getType();
4806+
4807+
// Validate supported data types
4808+
ONNXTRT_CHECK_NODE((dt == DataType::kFLOAT || dt == DataType::kHALF || dt == DataType::kBF16),
4809+
"Only float32/float16/bfloat16 inputs/outputs supported in RMSNormalization. The current data type = "
4810+
+ getTrtDtypeName(dt) + ".",
4811+
node, nodeIdx, ErrorCode::kUNSUPPORTED_NODE_DATATYPE);
4812+
4813+
// Get scale tensor
4814+
nvinfer1::ITensor* scale = &convertToTensor(inputs.at(1), ctx);
4815+
4816+
// Parse attributes
4817+
OnnxAttrs attrs(node, ctx);
4818+
float const epsilon = attrs.get("epsilon", 1e-5f);
4819+
int32_t axis = attrs.get("axis", -1);
4820+
nvinfer1::DataType computeType = nvinfer1::DataType::kFLOAT;
4821+
convertDtype(attrs.get<int32_t>("stash_type", 1), &computeType);
4822+
4823+
// Convert negative axis to positive
4824+
convertAxis(axis, nbDims, node, nodeIdx);
4825+
4826+
// Create axes mask for normalization (from axis to end)
4827+
uint32_t axesMask = 0;
4828+
for (int32_t i = axis; i < nbDims; i++)
4829+
{
4830+
axesMask |= 1 << i;
4831+
}
4832+
4833+
// Step 1: Square the input (X^2)
4834+
auto* sqrLayer = N_CHECK(ctx->network()->addElementWise(*input, *input, eOp::kPROD));
4835+
ctx->registerLayer(sqrLayer, node);
4836+
auto* xSquared = N_CHECK(sqrLayer->getOutput(0));
4837+
4838+
// Step 2: Mean of squared values (mean(X^2))
4839+
auto* meanLayer = N_CHECK(ctx->network()->addReduce(*xSquared, rOp::kAVG, axesMask, true));
4840+
ctx->registerLayer(meanLayer, node);
4841+
auto* meanSquared = N_CHECK(meanLayer->getOutput(0));
4842+
4843+
// Step 3: Add epsilon (mean(X^2) + epsilon)
4844+
nvinfer1::IConstantLayer* epsilonLayer;
4845+
if (dt == DataType::kHALF)
4846+
{
4847+
epsilonLayer = addConstantScalar(ctx, static_cast<half_float::half>(epsilon), ::ONNX_NAMESPACE::TensorProto::FLOAT16);
4848+
}
4849+
else if (dt == DataType::kBF16)
4850+
{
4851+
epsilonLayer = addConstantScalar(ctx, static_cast<BFloat16>(epsilon), ::ONNX_NAMESPACE::TensorProto::BFLOAT16);
4852+
}
4853+
else
4854+
{
4855+
epsilonLayer = addConstantScalar(ctx, epsilon, ::ONNX_NAMESPACE::TensorProto::FLOAT);
4856+
}
4857+
auto* epsilonTensor = N_CHECK(epsilonLayer->getOutput(0));
4858+
auto* addEpsLayer = N_CHECK(ctx->network()->addElementWise(*meanSquared, *epsilonTensor, eOp::kSUM));
4859+
ctx->registerLayer(addEpsLayer, node);
4860+
auto* meanPlusEps = N_CHECK(addEpsLayer->getOutput(0));
4861+
4862+
// Step 4: Square root (sqrt(mean(X^2) + epsilon) = RMS)
4863+
auto* sqrtLayer = N_CHECK(ctx->network()->addUnary(*meanPlusEps, uOp::kSQRT));
4864+
ctx->registerLayer(sqrtLayer, node);
4865+
auto* rms = N_CHECK(sqrtLayer->getOutput(0));
4866+
4867+
// Step 5: Divide input by RMS (X / RMS = normalized)
4868+
auto* divLayer = N_CHECK(ctx->network()->addElementWise(*input, *rms, eOp::kDIV));
4869+
ctx->registerLayer(divLayer, node);
4870+
auto* normalized = N_CHECK(divLayer->getOutput(0));
4871+
4872+
// Step 6: Broadcast scale to input size and multiply (normalized * scale)
4873+
broadcastTensors(ctx, normalized, scale);
4874+
auto* scaleLayer = N_CHECK(ctx->network()->addElementWise(*normalized, *scale, eOp::kPROD));
4875+
ctx->registerLayer(scaleLayer, node);
4876+
4877+
RETURN_FIRST_OUTPUT(scaleLayer, node, nodeIdx);
4878+
}
4879+
47944880
DEFINE_BUILTIN_OP_IMPORTER(Relu)
47954881
{
47964882
return activationHelper(ctx, node, nodeIdx, inputs, nvinfer1::ActivationType::kRELU);

onnx_backend_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
backend_test.include(r'.*test_reduce.*')
108108
backend_test.include(r'.*test_ReLU*')
109109
backend_test.include(r'.*test_relu.*')
110+
backend_test.include(r'.*test_rms_normalization.*')
110111
backend_test.include(r'.*test_selu.*')
111112
backend_test.include(r'.*test_shape.*')
112113
backend_test.include(r'.*test_Sigmoid*')

0 commit comments

Comments
 (0)