Commit 7d54ff8
Fix bug: ensure FP32 accumulation for dW in Llama-mode RMSNorm backward (#950)
## Summary
This PR fixes a precision issue in `_block_rms_norm_backward_kernel`
when running in `_CASTING_MODE_LLAMA`.
It enforces FP32 accumulation for the weight gradient (`dW`), aligning
its behavior with the existing `_rms_norm_backward_kernel`.
* In `_rms_norm_backward_kernel` (Row-wise): The gradient `dW_row` is
initialized as `float32`. When iterating through elements, the term
`dY_row * (X_row * rstd_row).to(X_dtype)` (which is `bfloat16`) is added
to `dW_row`. This operation performs an implicit cast to FP32 during the
addition `dW_row += val`, effectively accumulating in high precision.
* In `_block_rms_norm_backward_kernel` (Block-wise - The Bug): The code
uses `tl.sum` for the reduction:
```Python
# dY_row * (...) is bfloat16
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)), 0)
```
Here, `tl.sum` receives a `bfloat16` tensor. Consequently, the reduction
itself is performed in `bfloat16`. The precision loss occurs inside the
reduction before the result is added to the FP32 `dW_row`. This leads to
significant numerical errors for both small and large shapes due to the
limited mantissa of `bfloat16`.
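
For reference, here is a minimal sketch of the kind of change this implies; the patched lines are not reproduced in this excerpt, so treat the snippet as an illustration rather than the verbatim diff. The idea is to upcast the `bfloat16` product to `tl.float32` before it reaches `tl.sum`, so the reduction itself accumulates in FP32:
```Python
# Illustrative fix (not the verbatim patch): cast to fp32 *before* the
# reduction so tl.sum accumulates in full precision rather than bfloat16.
dW_row += tl.sum(
    (dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32),
    0,
)
```
To see why a `bfloat16` reduction loses accuracy, the standalone PyTorch snippet below (not Liger-Kernel code) accumulates the same values in `bfloat16` and in FP32. Sequential accumulation exaggerates the effect compared with a tree reduction, but the root cause is the same limited mantissa:
```Python
import torch

# 4096 small values; the true sum is roughly 41.
x = torch.full((4096,), 0.01, dtype=torch.bfloat16)

# Accumulate in bfloat16: once the running sum is large enough, each 0.01
# increment falls below half a ULP and rounds away, so the sum stalls.
acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
for v in x:
    acc_bf16 = acc_bf16 + v

# Accumulate in fp32: every increment is preserved.
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
for v in x:
    acc_fp32 = acc_fp32 + v.to(torch.float32)

print(acc_bf16.item(), acc_fp32.item())  # roughly 4.0 vs 41.0
```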
## Testing Done
- Hardware Type: A100-SXM4-80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---
1 parent 0a62700, commit 7d54ff8
1 file changed: +2 −1 (old line 352 replaced by new lines 352–353)