Skip to content

Commit f408175

Browse files
authored
avoid pointer mutation in add_rms_norm kernel (#1008)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Rewrite fused_add_rms_norm kernel to use explicit channel offsets instead of mutating X/Y base pointers inside loops. This improves Triton compiler optimization opportunities, enables more predictable memory access patterns, and avoids loop-carried pointer dependencies. ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <img width="1800" height="382" alt="image" src="https://github.com/user-attachments/assets/b361d41e-1379-4835-8acb-b0234d5af22b" /> - Hardware Type: Ascend NPU 910B4 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 4dd540e commit f408175

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

src/liger_kernel/ops/fused_add_rms_norm.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,21 @@ def _fused_add_rms_norm_backward_kernel(
162162

163163
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
164164

165-
dY_ptr += row_start * dY_row_stride
166-
dX_ptr += row_start * dX_row_stride
167-
if has_dS_out:
168-
dS_out_ptr += row_start * dS_out_row_stride
169-
170-
X_ptr += row_start * X_row_stride
171-
RSTD_ptr += row_start
172-
173165
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
174166
W_row = W_row + offset
175167

176-
for _ in range(row_start, row_end):
177-
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
178-
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
168+
for row_idx in range(row_start, row_end):
169+
dy_base = dY_ptr + row_idx * dY_row_stride
170+
dx_base = dX_ptr + row_idx * dX_row_stride
171+
172+
x_base = X_ptr + row_idx * X_row_stride
173+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
174+
175+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
176+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
179177

180178
# Get cached rms
181-
rstd_row = tl.load(RSTD_ptr)
179+
rstd_row = tl.load(rstd_base)
182180

183181
X_row = X_row.to(tl.float32)
184182

@@ -195,11 +193,11 @@ def _fused_add_rms_norm_backward_kernel(
195193
dX_row = rstd_row * m
196194

197195
if has_dS_out:
198-
dS_out_row = tl.load(dS_out_ptr + col_offsets, mask=mask, other=0.0)
196+
ds_base = dS_out_ptr + row_idx * dS_out_row_stride
197+
dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
199198
dX_row += (rstd_row) * (
200199
-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
201200
) + dS_out_row
202-
dS_out_ptr += dS_out_row_stride
203201
else:
204202
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
205203

@@ -210,12 +208,7 @@ def _fused_add_rms_norm_backward_kernel(
210208
# here X_row is already in fp32 (see previous if block)
211209
dW_row += dY_row * (X_row * rstd_row)
212210

213-
tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
214-
215-
dY_ptr += dY_row_stride
216-
dX_ptr += dX_row_stride
217-
X_ptr += X_row_stride
218-
RSTD_ptr += RSTD_row_stride
211+
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
219212

220213
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
221214

0 commit comments

Comments
 (0)