Skip to content

Commit 77949e0

Browse files
authored
[RMSNorm] Fix JIT recompilation by removing tl.constexpr on rows_per_program & Cleanup Block kernel interface (#988)
## Summary This PR optimizes the JIT compilation behavior for `_rms_norm_backward_kernel` and cleans up the interface for `_block_rms_norm_backward_kernel`. 1. Avoid JIT Recompilation: Removes `tl.constexpr` from the `rows_per_program` argument in `_rms_norm_backward_kernel`. 2. Interface Cleanup: Removes the unused `rows_per_program` argument from `_block_rms_norm_backward_kernel`. ## Details 1. Fix for Dynamic Shapes in `_rms_norm_backward_kernel`. Currently, `rows_per_program` is marked as `tl.constexpr`, but it is used within a standard dynamic `range` loop (not `tl.static_range`). * Issue: The `tl.constexpr` hint provides **no loop unrolling benefits** in this context because the loop bounds are determined at runtime (dependent on `n_rows` and `program_id`). However, Triton still treats the parameter as part of the kernel signature. * Impact: In dynamic shape scenarios (where `rows_per_program` changes with input size), this unnecessarily triggers JIT recompilation for every new shape, causing severe cache thrashing and CPU overhead without any performance gain. * Fix: Removing `tl.constexpr` allows the compiled kernel to be reused across different `rows_per_program` values. 2. Cleanup in `_block_rms_norm_backward_kernel`. The `rows_per_program` argument was unused in the block-wise implementation. It has been removed to avoid signature pollution and confusion. ## Testing Done Verified that the changes do not introduce performance regressions. The benchmark shows stable latency across different hidden sizes. **Performance Benchmark**: | Hidden Size | Latency (ms) | P50 (ms) | |-------------|-----------------|-----------| | 1024.00 | 0.13 | 0.11 | | 2048.00 | 0.12 | 0.12 | | 4096.00 | 0.12 | 0.12 | | 8192.00 | 0.12 | 0.11 | | 16384.00 | 0.18 | 0.18 | | 32768.00 | 1.37 | 1.39 | - Hardware Type: NVIDIA A100-SXM4-80GB - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
1 parent 44c2c31 commit 77949e0

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

src/liger_kernel/ops/rms_norm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _rms_norm_backward_kernel(
130130
n_rows,
131131
n_cols,
132132
offset,
133-
rows_per_program: tl.constexpr,
133+
rows_per_program,
134134
casting_mode: tl.constexpr,
135135
BLOCK_SIZE: tl.constexpr,
136136
):
@@ -293,7 +293,6 @@ def _block_rms_norm_backward_kernel(
293293
n_rows,
294294
n_cols,
295295
offset,
296-
rows_per_program: tl.constexpr,
297296
casting_mode: tl.constexpr,
298297
BLOCK_SIZE: tl.constexpr,
299298
BLOCK_ROW: tl.constexpr,
@@ -517,7 +516,6 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
517516
n_rows,
518517
n_cols,
519518
offset,
520-
rows_per_program,
521519
casting_mode,
522520
BLOCK_SIZE=BLOCK_SIZE,
523521
num_warps=num_warps,

0 commit comments

Comments
 (0)