Add full broadcasting support to LayerNormalization and RMSNormalization #26613
Conversation
cc @justinchuby, @amarin16, @tianleiwu, could you please review my PR?
Please fix build errors in CI pipelines. Overall logic aligns with ONNX broadcasting rules, and the optimized "fast path" for standard shapes is preserved, so there is no regression for existing models. Below is an AI analysis with recommendations to improve the code quality, performance, and testing of your implementation.

1. Performance: Optimize the Generic Inner Loop (Critical)

The Issue:

```cpp
// Current inner loop
for (int64_t h = 0; h < norm_size; ++h) {
  // ... calculate offsets ...
  // ... math ...
  // ... heavy logic to update idx array ...
}
```

This logic prevents the compiler from auto-vectorizing the normalization math (SIMD), which is crucial for performance.

Recommendation (Proposed Change, Concept):

```cpp
// In ComputeJobGeneric:
// 1. Identify properties of the innermost dimension.
int64_t last_dim_size = params.x_inner_dims.back();
int64_t sc_last_stride = params.sc_inner_inc.back();
int64_t bi_last_stride = has_bias ? params.bi_inner_inc.back() : 0;  // check has_bias first

// 2. Outer loop iterates over chunks; total iterations = norm_size / last_dim_size.
int64_t num_chunks = norm_size / last_dim_size;
for (int64_t c = 0; c < num_chunks; ++c) {
  // Calculate base offsets for the start of this chunk using the generic logic.
  int64_t off_sc = off_sc_row;
  int64_t off_bi = off_bi_row;
  for (size_t d = 0; d < static_cast<size_t>(last_rank) - 1; ++d) {  // -1 to skip the last dim
    off_sc += idx[d] * params.sc_inner_inc[d];
    if (bias_data) off_bi += idx[d] * params.bi_inner_inc[d];
  }

  // 3. Inner loop: tight, vectorizable loop over the last dimension.
  // The compiler can now see that the strides are constant (0 or 1) and vectorize this.
  int64_t base_h = c * last_dim_size;
  for (int64_t i = 0; i < last_dim_size; ++i) {
    int64_t h = base_h + i;
    // Simple stride addition
    T s = scale_data[off_sc + i * sc_last_stride];
    // ... perform math ...
  }

  // 4. Update the multi-dimensional index 'idx' for the *next* chunk,
  //    iterating backwards from the second-to-last dimension.
  for (int64_t d = last_rank - 2; d >= 0; --d) {
    if (++idx[static_cast<size_t>(d)] < params.x_inner_dims[static_cast<size_t>(d)]) break;
    idx[static_cast<size_t>(d)] = 0;
  }
}
```

Note: This requires handling the edge case where …

2. Code Quality: Reduce Duplication via Templates

The Issue: The generic broadcasting logic is duplicated across the float/double and MLFloat16 code paths, so any fix has to be applied in multiple places.

Recommendation:

```cpp
// Define a functor for Float/Double
template <typename T>
struct NormalizationMath {
  static T Load(const T* ptr, int64_t offset) { return ptr[offset]; }
  static void Store(T* ptr, int64_t offset, float val) { ptr[offset] = static_cast<T>(val); }
  // ... other helpers
};

// Define a functor for MLFloat16
struct HalfMath {
  static float Load(const MLFloat16* ptr, int64_t offset) { return static_cast<float>(ptr[offset]); }
  static void Store(MLFloat16* ptr, int64_t offset, float val) { ptr[offset] = MLFloat16(val); }
};

// Single generic implementation
template <typename DataT, typename MathPolicy>
void ComputeJobGenericShared(...) {
  // ... unified broadcasting logic ...
  // float s = MathPolicy::Load(scale_data, off_sc);
  // ...
  // MathPolicy::Store(Y_data, h, result);
}
```

This reduces code size and ensures that fixes to the complex broadcasting logic (like the optimization in point #1) automatically apply to all data types.
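For what it's worth, here is a minimal standalone sketch of the kind of dispatch this enables; it compiles on its own, and the Half, DefaultMath, HalfPolicy, and ScaleInPlace names are stand-ins for illustration, not the actual onnxruntime types or kernels:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct Half { float v; };  // stand-in for MLFloat16, illustrative only

// Policy for types that can be cast to/from float directly.
template <typename T>
struct DefaultMath {
  static float Load(const T* p, int64_t off) { return static_cast<float>(p[off]); }
  static void Store(T* p, int64_t off, float val) { p[off] = static_cast<T>(val); }
};

// Policy for the half-precision stand-in type.
struct HalfPolicy {
  static float Load(const Half* p, int64_t off) { return p[off].v; }
  static void Store(Half* p, int64_t off, float val) { p[off] = Half{val}; }
};

// Single shared implementation: all arithmetic happens in float, and every
// load/store goes through the policy, so a fix here applies to all data types.
template <typename DataT, typename MathPolicy>
void ScaleInPlace(DataT* data, int64_t n, float scale) {
  for (int64_t i = 0; i < n; ++i) {
    MathPolicy::Store(data, i, MathPolicy::Load(data, i) * scale);
  }
}

int main() {
  std::vector<float> f = {1.0f, 2.0f, 3.0f};
  ScaleInPlace<float, DefaultMath<float>>(f.data(), 3, 2.0f);

  std::vector<Half> h = {{1.0f}, {2.0f}, {3.0f}};
  ScaleInPlace<Half, HalfPolicy>(h.data(), 3, 2.0f);

  std::cout << f[1] << " " << h[1].v << std::endl;  // prints: 4 4
  return 0;
}
```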
3. Validation: Refine
…d improve input validation
@tianleiwu Thanks for the detailed review! Regarding the CI failures, they seem unrelated to this PR. I have applied all 5 suggested improvements.
Updated code pushed; please let me know if anything else needs adjustment.
@tianleiwu I reviewed the CI failures and it looks like the 4 failing CI jobs are unrelated to the changes here.
The latest commit is a massive improvement. It elegantly addresses the performance concerns by vectorizing the inner loop and significantly improves code maintainability via the shared templated implementation.

I have just one critical safety finding regarding input validation that could cause a crash (segfault) if not addressed.

1. Critical Safety Fix: Potential Segfault in Helper

In CheckInputs (layer_norm_helper.h):

The Issue:

```cpp
// layer_norm_helper.h
const size_t xr = x_shape.NumDimensions();
const size_t sr = scale_shape.NumDimensions();
// ...
// If sr > xr, then (xr - 1 - i) wraps around to SIZE_MAX when i >= xr
for (size_t i = 0; i < sr; ++i) {
  params.sc_dims[xr - 1 - i] = scale_shape.GetDims()[sr - 1 - i];
}
```
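For concreteness, a minimal standalone repro of the unsigned wrap-around (plain C++, not the helper code itself; the xr/sr values are arbitrary):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t xr = 2;  // input rank
  const std::size_t sr = 3;  // scale rank larger than input rank
  for (std::size_t i = 0; i < sr; ++i) {
    const std::size_t dst = xr - 1 - i;  // 1, 0, then wraps to SIZE_MAX
    std::printf("i=%zu dst=%zu\n", i, dst);
  }
  return 0;
}
```

With the wrapped index, the subsequent write lands far outside params.sc_dims, which is the potential segfault.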
Recommended Fix:

```cpp
// layer_norm_helper.h -> CheckInputs
const size_t xr = x_shape.NumDimensions();
const size_t sr = scale_shape.NumDimensions();
const size_t br = has_bias ? bias_shape.NumDimensions() : 0;
// ADD THIS CHECK:
if (sr > xr || (has_bias && br > xr)) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         kLayerNormInputShapeMismatchError,
                         " Scale/Bias rank cannot exceed Input rank.");
}
params.x_dims.clear();
// ...
```
Added the rank validation to CheckInputs as suggested. Thanks!
Description
This PR adds full and spec-compliant broadcasting support to both LayerNormalization and RMSNormalization.
Previously, onnxruntime supported only a partial set of broadcasting cases (based on the logic introduced in PR #23297).
That implementation handled several cases but did not cover all valid broadcasting scenarios.
This PR introduces a complete generic broadcasting path, following the ONNX specification rules.
The previous implementation is preserved as a fast-path and is still used whenever the Scale/Bias shapes match directly.
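To make the broadcasting semantics concrete, here is a minimal standalone sketch (plain C++, not onnxruntime code) of the kind of case this covers: a Scale of shape [2, 1] broadcast over a normalized span of shape [2, 4]. The shapes and values are purely illustrative.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // X: shape [1, 2, 4], axis = 1 -> normalize over the last 2 * 4 = 8 elements.
  std::vector<float> x = {0, 0, 0, 0, 1, 1, 1, 1};
  // Scale: shape [2, 1] -> broadcast along the last dimension of the normalized part.
  std::vector<float> scale = {2.0f, 3.0f};
  const float epsilon = 1e-5f;

  // Mean and variance over the normalized span.
  float mean = 0.0f, var = 0.0f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  const float inv_std = 1.0f / std::sqrt(var + epsilon);

  // Apply normalization; the scale index uses stride 0 along the broadcast dim.
  for (std::size_t r = 0; r < 2; ++r) {
    for (std::size_t c = 0; c < 4; ++c) {
      float y = (x[r * 4 + c] - mean) * inv_std * scale[r];  // scale[r * 1 + c * 0]
      std::printf("%.4f ", y);
    }
  }
  std::printf("\n");  // approximately: -2 -2 -2 -2 3 3 3 3
  return 0;
}
```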
Main changes:
- Extended broadcasting logic in:
  - layer_norm_helper.h
  - layer_norm_impl.cc
- Added full support for all valid broadcasting configurations of Scale and Bias.
- Preserved the previous partial logic as a fast-path for exact-match cases.
- Added comprehensive tests to:
  - layer_norm_op_test.cc
  - rms_norm_op_test.cc
Motivation and Context
Before this fix, some valid ONNX broadcasting shapes were rejected in LayerNormalization and RMSNormalization.
This PR brings the operators into full alignment with the ONNX specification and fixes models that previously failed due to incomplete broadcasting support.
Fixes #26432
Fixes #18184