Commit 77949e0
authored
[RMSNorm] Fix JIT recompilation by removing tl.constexpr on rows_per_program & Cleanup Block kernel interface (#988)
## Summary
This PR optimizes the JIT compilation behavior for
`_rms_norm_backward_kernel` and cleans up the interface for
`_block_rms_norm_backward_kernel`.
1. Avoid JIT Recompilation: Removes `tl.constexpr` from the
`rows_per_program` argument in `_rms_norm_backward_kernel`.
2. Interface Cleanup: Removes the unused `rows_per_program` argument
from `_block_rms_norm_backward_kernel`.
## Details
1. Fix for Dynamic Shapes in `_rms_norm_backward_kernel`. Currently,
`rows_per_program` is marked as `tl.constexpr`, but it is used within a
standard dynamic `range` loop (not `tl.static_range`).
* Issue: The `tl.constexpr` hint provides **no loop unrolling benefits**
in this context because the loop bounds are determined at runtime
(dependent on `n_rows` and `program_id`). However, Triton still treats
the parameter as part of the kernel signature.
* Impact: In dynamic shape scenarios (where `rows_per_program` changes
with input size), this unnecessarily triggers JIT recompilation for
every new shape, causing severe cache thrashing and CPU overhead without
any performance gain.
* Fix: Removing `tl.constexpr` allows the compiled kernel to be reused
across different `rows_per_program` values.
2. Cleanup in `_block_rms_norm_backward_kernel`. The `rows_per_program`
argument was unused in the block-wise implementation. It has been
removed to avoid signature pollution and confusion.
## Testing Done
Verified that the changes do not introduce performance regressions. The
benchmark shows stable latency across different hidden sizes.
**Performance Benchmark**:
| Hidden Size | Latency (ms) | P50 (ms) |
|-------------|-----------------|-----------|
| 1024.00 | 0.13 | 0.11 |
| 2048.00 | 0.12 | 0.12 |
| 4096.00 | 0.12 | 0.12 |
| 8192.00 | 0.12 | 0.11 |
| 16384.00 | 0.18 | 0.18 |
| 32768.00 | 1.37 | 1.39 |
- Hardware Type: NVIDIA A100-SXM4-80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence1 parent 44c2c31 commit 77949e0
1 file changed
+1
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
130 | 130 | | |
131 | 131 | | |
132 | 132 | | |
133 | | - | |
| 133 | + | |
134 | 134 | | |
135 | 135 | | |
136 | 136 | | |
| |||
293 | 293 | | |
294 | 294 | | |
295 | 295 | | |
296 | | - | |
297 | 296 | | |
298 | 297 | | |
299 | 298 | | |
| |||
517 | 516 | | |
518 | 517 | | |
519 | 518 | | |
520 | | - | |
521 | 519 | | |
522 | 520 | | |
523 | 521 | | |
| |||
0 commit comments