Conversation

@naomiOvad commented Nov 19, 2025

Description

This PR adds full and spec-compliant broadcasting support to both LayerNormalization and RMSNormalization.

Previously, onnxruntime supported only a partial set of broadcasting cases (based on the logic introduced in PR #23297).
That implementation handled several cases but did not cover all valid broadcasting scenarios.

This PR introduces a complete generic broadcasting path, following the ONNX specification rules.
The previous implementation is preserved as a fast-path and is still used whenever the Scale/Bias shapes match directly.
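
To make the generic rule concrete, here is a minimal, self-contained sketch of how ONNX-style unidirectional broadcasting of Scale onto X reduces to per-dimension strides. This is an illustration only, not the code added by the PR; ScaleBroadcastStrides is a hypothetical name. Shapes are right-aligned, and any dimension that Scale does not cover, or covers with size 1, gets stride 0, so the same element is reused along that axis.

#include <cstdint>
#include <vector>

// Hypothetical illustration: derive one stride per X dimension for Scale.
// Assumes sc_dims.size() <= x_dims.size(); the PR validates ranks in
// CheckInputs (see the review discussion below).
std::vector<int64_t> ScaleBroadcastStrides(const std::vector<int64_t>& x_dims,
                                           const std::vector<int64_t>& sc_dims) {
  const size_t xr = x_dims.size();
  const size_t sr = sc_dims.size();

  // Row-major (contiguous) strides of Scale itself.
  std::vector<int64_t> sc_strides(sr, 1);
  for (size_t i = sr; i-- > 1;) {
    sc_strides[i - 1] = sc_strides[i] * sc_dims[i];
  }

  // One stride per X dimension; leading dims not covered by Scale broadcast (stride 0).
  std::vector<int64_t> result(xr, 0);
  for (size_t i = 0; i < sr; ++i) {
    const size_t xd = xr - 1 - i;  // right-aligned X dimension
    const size_t sd = sr - 1 - i;  // matching Scale dimension
    result[xd] = (sc_dims[sd] == 1) ? 0 : sc_strides[sd];
  }
  return result;
}

For example, X of shape [2, 3, 4] with Scale of shape [3, 1] yields strides {0, 1, 0}: the Scale index advances along the middle dimension and the same element is reused along the first and last dimensions.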

Main changes:

  • Extended broadcasting logic in:
    layer_norm_helper.h
    layer_norm_impl.cc

  • Added full support for all valid broadcasting configurations of Scale and Bias.

  • Preserved previous partial logic as a fast-path for exact-match cases.

  • Added comprehensive tests to:
    layer_norm_op_test.cc
    rms_norm_op_test.cc

Motivation and Context

Before this fix, some valid ONNX broadcasting shapes were rejected in LayerNormalization and RMSNormalization.
This PR brings the operators into full alignment with the ONNX specification and fixes models that previously failed due to incomplete broadcasting support.

Fixes #26432
Fixes #18184

@naomiOvad naomiOvad marked this pull request as ready for review November 19, 2025 19:01
@naomiOvad (Author)

cc @justinchuby, @amarin16, @tianleiwu. Could you please review my PR?
Thanks!

@tianleiwu (Contributor) commented Nov 20, 2025

Please fix build errors in CI pipelines.

The overall logic aligns with the ONNX broadcasting rules, and the optimized "fast path" for standard shapes is preserved, so there should be no regression for existing models.

Below is an AI analysis:

Here are recommendations to improve the code quality, performance, and testing of your implementation.

1. Performance: Optimize the Generic Inner Loop (Critical)

The Issue:
The current ComputeJobGeneric implementation updates a multi-dimensional index array (idx) and calculates offsets in the innermost loop for every single element.

// Current Inner Loop
for (int64_t h = 0; h < norm_size; ++h) {
  // ... calculate offsets ...
  // ... math ...
  // ... heavy logic to update idx array ...
}

This logic prevents the compiler from auto-vectorizing the normalization math (SIMD), which is crucial for performance, especially for MLFloat16.

Recommendation:
Hoist the iteration of the innermost dimension out of the generic index tracking logic. Since the input X is always contiguous in the last dimension, and Scale/Bias will have a constant stride (either 0 or 1) for that dimension, exposing this allows the compiler to generate efficient vector code.

Proposed Change (Concept):
Modify ComputeJobGeneric to iterate in chunks of the last dimension.

// In ComputeJobGeneric:

// 1. Identify properties of the innermost dimension
int64_t last_dim_size = params.x_inner_dims.back();
int64_t sc_last_stride = params.sc_inner_inc.back();
int64_t bi_last_stride = has_bias ? params.bi_inner_inc.back() : 0; // Check has_bias first

// 2. Outer loop iterates over the "counts" of chunks
// Total iterations = norm_size / last_dim_size
int64_t num_chunks = norm_size / last_dim_size;

for (int64_t c = 0; c < num_chunks; ++c) {
    // Calculate base offsets for the start of this chunk using the generic logic
    int64_t off_sc = off_sc_row;
    int64_t off_bi = off_bi_row;
    for (size_t d = 0; d < static_cast<size_t>(last_rank) - 1; ++d) { // Note: -1 to skip last dim
        off_sc += idx[d] * params.sc_inner_inc[d];
        if (bias_data) off_bi += idx[d] * params.bi_inner_inc[d];
    }

    // 3. Inner loop: Tight, vectorized loop over the last dimension
    // The compiler can now verify that strides are constant (0 or 1) and vectorize this.
    int64_t base_h = c * last_dim_size;
    for (int64_t i = 0; i < last_dim_size; ++i) {
        int64_t h = base_h + i;
        // Use simple stride addition
        T s = scale_data[off_sc + i * sc_last_stride];
        // ... Perform Math ...
    }

    // 4. Update the multi-dimensional index 'idx' for the *next* chunk
    // Iterate backwards from the second-to-last dimension
    for (int64_t d = last_rank - 2; d >= 0; --d) {
        if (++idx[static_cast<size_t>(d)] < params.x_inner_dims[static_cast<size_t>(d)]) break;
        idx[static_cast<size_t>(d)] = 0;
    }
}

Note: This requires handling the edge case where last_rank == 0 or 1 separately or ensuring last_rank - 1 logic holds.

2. Code Quality: Reduce Duplication via Templates

The Issue:
You have two copies of ComputeJobGeneric: one for T (float/double) and one for MLFloat16. The broadcasting logic and loop structure are identical; only the load/store/cast mechanics differ.

Recommendation:
Refactor the broadcasting logic into a single function that accepts a Calculator Policy or a Functor.

// Define a functor for Float/Double
template <typename T>
struct NormalizationMath {
    static T Load(const T* ptr, int64_t offset) { return ptr[offset]; }
    static void Store(T* ptr, int64_t offset, float val) { ptr[offset] = static_cast<T>(val); }
    // ... other helpers
};

// Define a functor for MLFloat16
struct HalfMath {
    static float Load(const MLFloat16* ptr, int64_t offset) { return static_cast<float>(ptr[offset]); }
    static void Store(MLFloat16* ptr, int64_t offset, float val) { ptr[offset] = MLFloat16(val); }
};

// Single Generic Implementation
template <typename DataT, typename MathPolicy>
void ComputeJobGenericShared(...) {
    // ... unified broadcasting logic ...
    // float s = MathPolicy::Load(scale_data, off_sc);
    // ...
    // MathPolicy::Store(Y_data, h, result);
}

This reduces code size and ensures that fixes to the complex broadcasting logic (like the optimization in point #1) automatically apply to all data types.

3. Validation: Refine CheckInputs Constraints

The Issue:
The check params.norm_size <= 1 returns an error. While rejecting this case is mathematically defensible (the variance is 0), standard implementations often handle scalar normalization gracefully (returning 0) rather than erroring, and norm_size can legitimately be 1 when axis points at a trailing dimension of size 1.

Recommendation:
Ensure that norm_size == 1 behaves gracefully. If the input is [Batch, 1], normalization usually results in zeros (since mean=value, value-mean=0).
Also, in LayerNormHelper::CheckInputs:

// While strictly correct for variance calc, consider if this check is too aggressive for edge cases 
// where dimensions might be dynamically 1.
if (params.norm_size < 1) { ... } // Change <= to < ?

Note: If your specific kernel requires variance calculation that divides by norm_size, 1 is fine. If it divides by norm_size - 1 (unbiased), 1 is bad.
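
For illustration, here is a hedged sketch of what graceful norm_size == 1 behavior could look like as a test (hypothetical test name and values). With a single element per normalization group, every element equals its own mean, so the normalized value is exactly 0 and the output is all zeros (or equal to the bias, if one is supplied).

TEST(LayerNormTest, LayerNorm_NormSizeOne) {
  OpTester test("LayerNormalization", 17);
  test.AddAttribute<float>("epsilon", 1e-05f);
  test.AddAttribute<int64_t>("axis", -1);  // normalize over the last dim, which has size 1

  test.AddInput<float>("X", {2, 1}, {3.0f, -7.0f});
  test.AddInput<float>("Scale", {1}, {2.0f});

  // (x - mean) is 0 for every element, so Y is all zeros regardless of Scale.
  test.AddOutput<float>("Y", {2, 1}, {0.0f, 0.0f});
  test.Run();
}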

4. Testing: Add "Mixed" Broadcasting Cases

The Issue:
The current tests cover:

  1. No broadcasting.
  2. Outer broadcasting (Batch broadcast).
  3. Inner broadcasting (scalar scale).

Recommendation:
Add a test case for Mixed/Strided Broadcasting to ensure the generic loop truly handles non-contiguous strides correctly.

New Test Case Example:

  • Input: [1, 2, 4] (Batch=1, Seq=2, Feature=4)
  • Axis: 1 (Normalize over [2, 4])
  • Scale: [1, 4] (Broadcast over Seq dim 2).
    • Here, Scale has shape [1, 4] matching the inner dims [2, 4] via broadcasting 1 -> 2.
    • The stride for Scale on the Seq dimension is 0, but for Feature dimension is 1.
    • This exercises the specific logic where one inner stride is 0 and another is 1.
TEST(LayerNormTest, LayerNorm_Scale_Broadcast_Inner_Mixed) {
  OpTester test("LayerNormalization", 17);
  test.AddAttribute<float>("epsilon", 1e-05f);
  test.AddAttribute<int64_t>("axis", 1); // Normalize over last 2 dims (2, 4)

  // Input: [1, 2, 4]
  std::vector<int64_t> dims{1, 2, 4};
  std::vector<float> x = {
      0, 1, 2, 3,   // Row 0
      4, 5, 6, 7    // Row 1
  }; 
  test.AddInput<float>("X", dims, x);

  // Scale: [1, 4]. Broadcasts to [2, 4]
  // Stride on dim '2' is 0. Stride on dim '4' is 1.
  std::vector<float> scale = {1.0f, 0.5f, 1.0f, 0.5f}; 
  test.AddInput<float>("Scale", {1, 4}, scale); 

  // Expected output calculation required...
  test.Run();
}
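
For reference, a worked expectation for this sketch (assuming standard LayerNormalization semantics with biased variance and treating epsilon as negligible): the eight inputs have mean 3.5 and variance 5.25, so the normalized values are approximately {-1.5275, -1.0911, -0.6547, -0.2182, 0.2182, 0.6547, 1.0911, 1.5275}; after applying the broadcast Scale {1, 0.5, 1, 0.5}, the expected Y is approximately {-1.5275, -0.5455, -0.6547, -0.1091, 0.2182, 0.3273, 1.0911, 0.7638}.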

5. Pre-calculation Optimization

The Issue:
Inside CheckInputs, you calculate sc_outer_inc, bi_outer_inc etc.

params.sc_outer_inc.push_back(params.sc_strides[static_cast<size_t>(i)]);

These vectors are rebuilt (push_back) on every Compute call, since LayerNormParams is stack-allocated per call, and they may heap-allocate if the inlined capacity is exceeded.

Recommendation:
Since LayerNormParams is recreated every call, avoid InlinedVector::push_back inside the loop if possible, or reserve size upfront.
More importantly, CheckInputs runs on every inference.

  • Ensure InlinedVector capacity (8) is sufficient for common models (it usually is).
  • Consider calculating sc_inner_inc and bi_inner_inc only if use_generic_broadcast becomes true, or lazily. Currently, you calculate them always, even for the fast path which doesn't use them.

Move heavy setup inside the conditional:

// In LayerNormHelper::CheckInputs
// ... detect if fast path is possible ...

if (params.broadcast_param == kLayerNormInvalidInput || outer_dep) {
   params.use_generic_broadcast = true;
   // ONLY compute expensive generic strides (sc_inner_inc, etc.) HERE
   // This saves overhead for the 99% case (BERT/Transformer fast path).
}

@naomiOvad (Author)

@tianleiwu Thanks for the detailed review!

Regarding the CI failures: they seem unrelated to this PR. The failing pipelines are due to transient infrastructure issues (GitHub 503 errors when downloading googletest). Once GitHub is back to normal, a rerun should resolve them.

I’ve applied all 5 suggested improvements:

  1. optimized the generic inner loop,
  2. unified the generic implementation using MathPolicy,
  3. refined norm_size validation,
  4. added the mixed-broadcast test,
  5. moved generic stride setup inside the generic-path branch.

Updated code pushed. Please let me know if anything else needs adjustment.

@naomiOvad (Author)

@tianleiwu I reviewed the CI failures, and the 4 failing jobs look unrelated to the changes here.
The macOS job fails in vcpkg (ranlib I/O error), Windows CUDA has a generic build failure without compiler errors, and the QNN jobs fail due to unsupported ops/layouts inside the QNN EP.
Let me know if there’s anything specific I should adjust. Thanks!

@tianleiwu (Contributor) commented Nov 24, 2025

The latest commit is a massive improvement. It elegantly addresses the performance concerns by vectorizing the inner loop and significantly improves code maintainability via the NormalizationMath policy functors. The generic path is now both robust and performant.

I have just one critical safety finding regarding input validation that could cause a crash (segfault) if not addressed.

1. Critical Safety Fix: Potential Segfault in Helper

In layer_norm_helper.h, the logic assumes the rank of Scale (and Bias) is less than or equal to the rank of X. While the ONNX spec implies this for unidirectional broadcasting, passing a malformed model where Scale.ndim > X.ndim will currently cause a segfault (buffer overflow) instead of a graceful error.

The Issue:

// layer_norm_helper.h
const size_t xr = x_shape.NumDimensions();
const size_t sr = scale_shape.NumDimensions();

// ...

// If sr > xr, then (xr - 1 - i) wraps around to SIZE_MAX when i >= xr
for (size_t i = 0; i < sr; ++i) {
  params.sc_dims[xr - 1 - i] = scale_shape.GetDims()[sr - 1 - i]; 
}

Recommended Fix:
Add a check at the top of CheckInputs to validate ranks.

// layer_norm_helper.h -> CheckInputs

const size_t xr = x_shape.NumDimensions();
const size_t sr = scale_shape.NumDimensions();
const size_t br = has_bias ? bias_shape.NumDimensions() : 0;

// ADD THIS CHECK:
if (sr > xr || (has_bias && br > xr)) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, 
                         kLayerNormInputShapeMismatchError,
                         " Scale/Bias rank cannot exceed Input rank.");
}

params.x_dims.clear();
// ...

@naomiOvad (Author)

Added the rank validation to CheckInputs as suggested. Thanks!
