
Commit 7d54ff8

Fix bug: ensure FP32 accumulation for dW in Llama-mode RMSNorm backward (#950)
## Summary

This PR fixes a precision issue in `_block_rms_norm_backward_kernel` when running in `_CASTING_MODE_LLAMA`. It enforces FP32 accumulation for the weight gradient (`dW`), aligning its behavior with the existing `_rms_norm_backward_kernel`.

* In `_rms_norm_backward_kernel` (row-wise): the gradient `dW_row` is initialized as `float32`. When iterating through elements, the term `dY_row * (X_row * rstd_row).to(X_dtype)` (which is `bfloat16`) is added to `dW_row`. The addition `dW_row += val` performs an implicit cast to FP32, so the accumulation effectively happens in high precision.
* In `_block_rms_norm_backward_kernel` (block-wise, the bug): the code uses `tl.sum` for the reduction:

  ```python
  # dY_row * (...) is bfloat16
  dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)), 0)
  ```

  Here `tl.sum` receives a `bfloat16` tensor, so the reduction itself is performed in `bfloat16`. The precision loss occurs inside the reduction, before the result is added to the FP32 `dW_row`. Because of `bfloat16`'s limited mantissa, this leads to significant numerical errors for both small and large shapes.

## Testing Done

- Hardware Type: A100-SXM4-80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
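As an illustration that is not part of the original PR description, the sketch below mimics the difference in plain PyTorch: it sums `bfloat16` values while keeping every partial sum in `bfloat16`, then again after casting to `float32`, and compares both against a `float64` reference. The helper `reduce_with_acc_dtype`, the shapes, and the random data are assumptions made for this example only.

```python
# Standalone sketch (not from the repository): why the accumulation dtype of a
# reduction matters for bfloat16 inputs.
import torch


def reduce_with_acc_dtype(x: torch.Tensor, acc_dtype: torch.dtype) -> torch.Tensor:
    """Sequentially sum `x`, keeping every partial sum in `acc_dtype`."""
    acc = torch.zeros((), dtype=acc_dtype)
    for v in x:
        acc = acc + v.to(acc_dtype)  # each partial sum is rounded to acc_dtype
    return acc


torch.manual_seed(0)
x = torch.randn(4096).to(torch.bfloat16)  # bfloat16 inputs, like dY_row * (...)

ref = x.to(torch.float64).sum()  # exact sum of the bfloat16 values
err_bf16 = (reduce_with_acc_dtype(x, torch.bfloat16).to(torch.float64) - ref).abs()
err_fp32 = (reduce_with_acc_dtype(x, torch.float32).to(torch.float64) - ref).abs()

print(f"bf16 accumulation error: {err_bf16.item():.6f}")
print(f"fp32 accumulation error: {err_fp32.item():.6f}")
```

On typical runs the error of the bf16 accumulation is several orders of magnitude larger than that of the fp32 accumulation, which is the effect the fix removes from `dW`.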
1 parent 0a62700 commit 7d54ff8

File tree

1 file changed: +2 -1 lines changed


src/liger_kernel/ops/rms_norm.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -349,7 +349,8 @@ def _block_rms_norm_backward_kernel(

         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
         else:
             # here X_row is already in fp32 (see previous if block)
             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
```
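For context, a hypothetical, self-contained sketch of the simplification the TODO in the diff points at: with triton>=3.3.0 (per that TODO), `tl.sum` can be asked to accumulate in `float32` directly through its `dtype` argument instead of casting the whole bfloat16 product beforehand. The kernel name, shapes, and launch parameters below are illustrative and not taken from the repository.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _bf16_sum_fp32_acc(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Load a block of bfloat16 values.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Ask the reduction itself to accumulate in float32
    # (assumes the dtype argument available in triton>=3.3.0, as the TODO states).
    total = tl.sum(x, axis=0, dtype=tl.float32)
    tl.store(out_ptr, total)


x = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
out = torch.empty(1, device="cuda", dtype=torch.float32)
_bf16_sum_fp32_acc[(1,)](x, out, x.numel(), BLOCK_SIZE=4096)
print(out.item(), x.float().sum().item())  # the two values should agree closely
```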
