Skip to content

Commit 2ce9bdd

Browse files
authored
avoid pointer mutation in layer_norm kernel (#1006)
## Summary

Rewrite the layer_norm kernel to use explicit channel offsets instead of mutating the X/Y base pointers inside loops. This improves Triton compiler optimization opportunities, enables more predictable memory access patterns, and avoids loop-carried pointer dependencies.

## Testing Done

<img width="1810" height="448" alt="image" src="https://github.com/user-attachments/assets/958b2c21-2f1c-42a2-af56-49a22e5c1b83" />

- Hardware Type: Ascend NPU 910B4
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 41d4bcf commit 2ce9bdd

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

src/liger_kernel/ops/layer_norm.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from liger_kernel.ops.utils import calculate_settings
99
from liger_kernel.ops.utils import compare_version
1010
from liger_kernel.ops.utils import ensure_contiguous
11+
from liger_kernel.utils import get_npu_multi_processor_count
1112
from liger_kernel.utils import is_npu_available
1213

1314
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -124,14 +125,14 @@ def _layer_norm_backward_kernel(
124125
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
125126
w_f32 = w.to(tl.float32)
126127

127-
# Calculate pointers for this specific row
128-
row_X_ptr = X_ptr + row_start * stride_x
129-
row_DX_ptr = DX_ptr + row_start * stride_dx
130-
row_DY_ptr = DY_ptr + row_start * stride_dy
131-
row_Mean_ptr = Mean_ptr + row_start
132-
row_RSTD_ptr = RSTD_ptr + row_start
128+
for row_idx in range(row_start, row_end):
129+
# Calculate pointers for this specific row
130+
row_X_ptr = X_ptr + row_idx * stride_x
131+
row_DX_ptr = DX_ptr + row_idx * stride_dx
132+
row_DY_ptr = DY_ptr + row_idx * stride_dy
133+
row_Mean_ptr = Mean_ptr + row_idx * stride_mean
134+
row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
133135

134-
for _ in range(row_start, row_end):
135136
# Load data for this row
136137
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
137138
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
@@ -160,12 +161,6 @@ def _layer_norm_backward_kernel(
160161
dW_row += dw
161162
db_row += db
162163

163-
row_X_ptr += stride_x
164-
row_DX_ptr += stride_dx
165-
row_DY_ptr += stride_dy
166-
row_Mean_ptr += stride_mean
167-
row_RSTD_ptr += stride_rstd
168-
169164
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
170165
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
171166

@@ -254,6 +249,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
254249
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
255250
elif X.device.type == "xpu":
256251
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
252+
elif X.device.type == "npu":
253+
sm_count = get_npu_multi_processor_count()
257254

258255
# fp32 for numerical stability especially.
259256
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -301,6 +298,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
301298
DX = DX.view(*shape)
302299
DW = _DW.sum(dim=0).to(W.dtype)
303300
DB = _DB.sum(dim=0).to(B.dtype)
301+
304302
return DX, DW, DB
305303

306304

0 commit comments

Comments (0)