Commit b27b4c5

vvvdwbvvv and lancerts authored
Fix illegal memory access in Triton RMSNorm and RoPE (#804)
## Summary

When using very large tensors (e.g. seq_len=1e6, hidden_size=4096), offsets computed from Triton's default 32-bit `tl.program_id(0)` can overflow, leading to out-of-bounds memory accesses. This change casts the program ID to 64-bit (`tl.int64`) so that all pointer arithmetic stays within the valid address range.

Fix #803

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
1 parent 9cf2019 commit b27b4c5
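
The sizes quoted in the summary can be checked by hand. The following is a minimal plain-Python sketch, not code from this commit: it assumes a contiguous `(seq_len, hidden_size)` tensor whose row stride equals `hidden_size`, and it models 32-bit overflow as two's-complement wraparound through a hypothetical `wrap_int32` helper.

```python
INT32_MAX = 2**31 - 1


def wrap_int32(value: int) -> int:
    """Hypothetical helper: the value a signed 32-bit integer would hold
    after two's-complement wraparound."""
    return (value + 2**31) % 2**32 - 2**31


seq_len = 1_000_000   # from the commit summary
hidden_size = 4096    # from the commit summary; assumed to be the row stride

# Element offset of the last row's first element, computed exactly.
last_row = seq_len - 1
true_offset = last_row * hidden_size

# What the same product looks like if it is carried out in 32 bits.
wrapped_offset = wrap_int32(true_offset)

# First row index whose offset no longer fits in a signed 32-bit integer.
first_bad_row = INT32_MAX // hidden_size + 1

print(true_offset > INT32_MAX)  # True: the offset needs more than 32 bits
print(wrapped_offset)           # negative, i.e. an address before the tensor
print(first_bad_row)            # every row from here on is affected
```

Under these assumptions roughly the second half of the rows would be addressed incorrectly, which is consistent with the out-of-bounds accesses described in the summary.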

2 files changed: +3 -3 lines changed

src/liger_kernel/ops/rms_norm.py

Lines changed: 2 additions & 2 deletions

@@ -63,7 +63,7 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
@@ -137,7 +137,7 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
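
To see which row blocks of the backward kernel are exposed to the overflow, the `row_start` / `row_end` expressions from the hunk above can be emulated on the host. This is an illustrative sketch only: `num_programs` is a hypothetical grid size, `rows_per_program` is assumed to be the ceiling of `n_rows / num_programs`, and the tensor is assumed contiguous with a row stride of `n_cols`.

```python
import math

INT32_MAX = 2**31 - 1


def row_blocks(n_rows: int, num_programs: int):
    """Yield (row_block_id, row_start, row_end) using the same expressions
    that appear in the diff context."""
    rows_per_program = math.ceil(n_rows / num_programs)  # assumed launch-side choice
    for row_block_id in range(num_programs):
        row_start = row_block_id * rows_per_program
        row_end = min((row_block_id + 1) * rows_per_program, n_rows)
        yield row_block_id, row_start, row_end


n_rows, n_cols = 1_000_000, 4096  # sizes from the commit summary
num_programs = 128                # hypothetical grid size

blocks = list(row_blocks(n_rows, num_programs))

# A block is affected if the offset of its last row (row * n_cols)
# does not fit in a signed 32-bit integer.
overflowing_blocks = [
    block for block in blocks if (block[2] - 1) * n_cols > INT32_MAX
]

print(blocks[-1])
print(len(overflowing_blocks), "of", len(blocks), "blocks touch rows past the int32 limit")
```

Because `row_start` is derived from `row_block_id`, casting the program ID to `tl.int64` once at the top of the kernel is enough for the later row offsets to be computed in 64 bits as well.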

src/liger_kernel/ops/rope.py

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
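
The `pid * q_row_stride` product in the last context line is where a 32-bit `pid` becomes a problem. As a rough host-side illustration, the sketch below estimates how many program IDs stay safe for a given row stride. Both helpers are hypothetical and are not part of Liger Kernel, and the stride of 4096 is an assumed example value (for instance 32 heads times a head dimension of 128).

```python
INT32_MAX = 2**31 - 1


def max_safe_programs(row_stride: int) -> int:
    """Hypothetical helper: count of program IDs (0 .. n-1) for which
    pid * row_stride still fits in a signed 32-bit integer."""
    return INT32_MAX // row_stride + 1


def needs_int64_pid(n_programs: int, row_stride: int) -> bool:
    """Hypothetical helper: True if some pid * row_stride would exceed int32."""
    return n_programs > max_safe_programs(row_stride)


q_row_stride = 4096  # assumed example stride, in elements

print(max_safe_programs(q_row_stride))
print(needs_int64_pid(1_000_000, q_row_stride))  # a grid of one million programs
```

This check only covers the start address of each row; offsets added within a row would push the limit slightly lower, so it should be read as an upper bound rather than an exact threshold.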
