diff --git a/docs/operators.md b/docs/operators.md
index 39034cd..2b25fff 100644
--- a/docs/operators.md
+++ b/docs/operators.md
@@ -161,6 +161,7 @@ TensorRT supports the following ONNX data types: DOUBLE, FLOAT32, FLOAT16, BFLOA
 | Reshape | Y | FP32, FP16, BF16, INT32, INT64, BOOL | |
 | Resize | Y | FP32, FP16, BF16 | Supported resize transformation modes: `half_pixel`, `pytorch_half_pixel`, `tf_half_pixel_for_nn`, `asymmetric`, and `align_corners`.<br />Supported resize modes: `nearest`, `linear`.<br />Supported nearest modes: `floor`, `ceil`, `round_prefer_floor`, `round_prefer_ceil`.<br />Supported aspect ratio policy: `stretch`.<br />When `scales` is a tensor input, `axes` must be an iota vector of length rank(input).<br />Antialiasing is not supported.|
 | ReverseSequence | Y | FP32, FP16, BF16, INT32, INT64, BOOL |
+| RMSNormalization | Y | FP32, FP16, BF16 | Only the first output `Y` is supported. Introduced in opset 23. |
 | RNN | Y | FP32, FP16, BF16| For bidirectional RNNs, activation functions must be the same for both the forward and reverse pass |
 | RoiAlign | Y | FP32, FP16 | |
 | Round | Y | FP32, FP16, BF16 |
diff --git a/onnxOpCheckers.cpp b/onnxOpCheckers.cpp
index d1dc696..7369d6e 100644
--- a/onnxOpCheckers.cpp
+++ b/onnxOpCheckers.cpp
@@ -599,6 +599,8 @@ DEFINE_OP_EMPTY_CHECKER(ReduceSum)
 
 DEFINE_OP_EMPTY_CHECKER(ReduceSumSquare)
 
+DEFINE_OP_EMPTY_CHECKER(RMSNormalization)
+
 DEFINE_OP_EMPTY_CHECKER(Relu)
 
 DEFINE_OP_EMPTY_CHECKER(Sign)
diff --git a/onnxOpImporters.cpp b/onnxOpImporters.cpp
index d3e644b..4cc7f16 100644
--- a/onnxOpImporters.cpp
+++ b/onnxOpImporters.cpp
@@ -4791,6 +4791,92 @@ DEFINE_BUILTIN_OP_IMPORTER(ReduceSumSquare)
         inputs.size() >= 2 ? inputs.at(1) : TensorOrWeights());
 }
 
+// RMSNormalization: Y = (X / sqrt(mean(X^2) + epsilon)) * scale
+// Introduced in ONNX opset 23
+DEFINE_BUILTIN_OP_IMPORTER(RMSNormalization)
+{
+    using eOp = nvinfer1::ElementWiseOperation;
+    using uOp = nvinfer1::UnaryOperation;
+    using rOp = nvinfer1::ReduceOperation;
+
+    // Get input tensor
+    nvinfer1::ITensor* input = &convertToTensor(inputs.at(0), ctx);
+    auto const nbDims = input->getDimensions().nbDims;
+    auto const dt = input->getType();
+
+    // Validate supported data types
+    ONNXTRT_CHECK_NODE((dt == DataType::kFLOAT || dt == DataType::kHALF || dt == DataType::kBF16),
+        "Only float32/float16/bfloat16 inputs/outputs supported in RMSNormalization. "
+        "The current data type = " + getTrtDtypeName(dt) + ".",
+        node, nodeIdx, ErrorCode::kUNSUPPORTED_NODE_DATATYPE);
+
+    // Get scale tensor
+    nvinfer1::ITensor* scale = &convertToTensor(inputs.at(1), ctx);
+
+    // Parse attributes
+    OnnxAttrs attrs(node, ctx);
+    float const epsilon = attrs.get<float>("epsilon", 1e-5f);
+    int32_t axis = attrs.get<int32_t>("axis", -1);
+    nvinfer1::DataType computeType = nvinfer1::DataType::kFLOAT;
+    convertDtype(attrs.get<int32_t>("stash_type", 1), &computeType);
+
+    // Convert negative axis to positive
+    convertAxis(axis, nbDims, node, nodeIdx);
+
+    // Create axes mask for normalization (from axis to end)
+    uint32_t axesMask = 0;
+    for (int32_t i = axis; i < nbDims; i++)
+    {
+        axesMask |= 1 << i;
+    }
+
+    // Step 1: Square the input (X^2)
+    auto* sqrLayer = N_CHECK(ctx->network()->addElementWise(*input, *input, eOp::kPROD));
+    ctx->registerLayer(sqrLayer, node);
+    auto* xSquared = N_CHECK(sqrLayer->getOutput(0));
+
+    // Step 2: Mean of squared values (mean(X^2))
+    auto* meanLayer = N_CHECK(ctx->network()->addReduce(*xSquared, rOp::kAVG, axesMask, true));
+    ctx->registerLayer(meanLayer, node);
+    auto* meanSquared = N_CHECK(meanLayer->getOutput(0));
+
+    // Step 3: Add epsilon (mean(X^2) + epsilon)
+    nvinfer1::IConstantLayer* epsilonLayer;
+    if (dt == DataType::kHALF)
+    {
+        epsilonLayer = addConstantScalar(ctx, static_cast<half_float::half>(epsilon), ::ONNX_NAMESPACE::TensorProto::FLOAT16);
+    }
+    else if (dt == DataType::kBF16)
+    {
+        epsilonLayer = addConstantScalar(ctx, static_cast<BFloat16>(epsilon), ::ONNX_NAMESPACE::TensorProto::BFLOAT16);
+    }
+    else
+    {
+        epsilonLayer = addConstantScalar(ctx, epsilon, ::ONNX_NAMESPACE::TensorProto::FLOAT);
+    }
+    auto* epsilonTensor = N_CHECK(epsilonLayer->getOutput(0));
+    auto* addEpsLayer = N_CHECK(ctx->network()->addElementWise(*meanSquared, *epsilonTensor, eOp::kSUM));
+    ctx->registerLayer(addEpsLayer, node);
+    auto* meanPlusEps = N_CHECK(addEpsLayer->getOutput(0));
+
+    // Step 4: Square root (sqrt(mean(X^2) + epsilon) = RMS)
+    auto* sqrtLayer = N_CHECK(ctx->network()->addUnary(*meanPlusEps, uOp::kSQRT));
+    ctx->registerLayer(sqrtLayer, node);
+    auto* rms = N_CHECK(sqrtLayer->getOutput(0));
+
+    // Step 5: Divide input by RMS (X / RMS = normalized)
+    auto* divLayer = N_CHECK(ctx->network()->addElementWise(*input, *rms, eOp::kDIV));
+    ctx->registerLayer(divLayer, node);
+    auto* normalized = N_CHECK(divLayer->getOutput(0));
+
+    // Step 6: Broadcast scale to input size and multiply (normalized * scale)
+    broadcastTensors(ctx, normalized, scale);
+    auto* scaleLayer = N_CHECK(ctx->network()->addElementWise(*normalized, *scale, eOp::kPROD));
+    ctx->registerLayer(scaleLayer, node);
+
+    RETURN_FIRST_OUTPUT(scaleLayer, node, nodeIdx);
+}
+
 DEFINE_BUILTIN_OP_IMPORTER(Relu)
 {
     return activationHelper(ctx, node, nodeIdx, inputs, nvinfer1::ActivationType::kRELU);
diff --git a/onnx_backend_test.py b/onnx_backend_test.py
index f62004a..ecbf469 100644
--- a/onnx_backend_test.py
+++ b/onnx_backend_test.py
@@ -107,6 +107,7 @@
 backend_test.include(r'.*test_reduce.*')
 backend_test.include(r'.*test_ReLU*')
 backend_test.include(r'.*test_relu.*')
+backend_test.include(r'.*test_rms_normalization.*')
 backend_test.include(r'.*test_selu.*')
 backend_test.include(r'.*test_shape.*')
 backend_test.include(r'.*test_Sigmoid*')