Skip to content

Commit a324d3a

Browse files
committed
optimize group_norm for ASCEND_NPU
1 parent 781083b commit a324d3a

File tree

1 file changed

+132
-186
lines changed

1 file changed

+132
-186
lines changed

src/liger_kernel/ops/backends/_ascend/ops/group_norm.py

Lines changed: 132 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# -----------------------------------------------------------------------------
1414

1515

16+
MAX_FUSED_SIZE = 16384
17+
18+
1619
@triton.jit
1720
def _group_norm_forward_kernel(
1821
Y_ptr, # pointer to output, shape (B, G, hidden_size)
@@ -109,139 +112,107 @@ def _group_norm_forward_kernel(
109112

110113
@triton.jit
111114
def _group_norm_backward_kernel(
112-
X_ptr, # pointer to input, shape (B, G, hidden_size)
113-
X_row_stride, # stride of each batch row in X
114-
X_col_stride, # stride of each group row in X
115-
W_ptr, # pointer to affine scale weights, shape (C)
116-
Mean_ptr, # pointer to saved group mean, shape (B, G)
117-
Mean_row_stride, # stride of each batch row in Mean
118-
Mean_col_stride, # stride of each group row in Mean
119-
RSTD_ptr, # pointer to saved reciprocal std, shape (B, G)
120-
DX_ptr, # pointer to input gradients, shape (B, G, hidden_size)
121-
DW_scratch_ptr, # pointer to scratch buffer for dW partial sums, shape (grid, C)
122-
DW_scratch_stride, # row stride for DW_scratch
123-
DB_scratch_ptr, # pointer to scratch buffer for dB partial sums, shape (grid, C)
124-
DB_scratch_stride, # row stride for DB_scratch
125-
DY_ptr, # pointer to upstream gradients, shape (B, G, hidden_size)
126-
DY_row_stride, # stride of each batch row in DY
127-
DY_col_stride, # stride of each group row in DY
128-
n_rows, # total logical rows = B * G
129-
hidden_size,
130-
channels_per_group,
131-
num_groups,
132-
SINGLE_CHANNEL_TILE: tl.constexpr,
133-
COMPUTE_PARAM_GRAD: tl.constexpr,
134-
BLOCK_SIZE_M: tl.constexpr,
135-
BLOCK_SIZE_N: tl.constexpr,
115+
X_ptr,
116+
W_ptr,
117+
Mean_ptr,
118+
RSTD_ptr,
119+
DX_ptr,
120+
DW_partial_ptr,
121+
DB_partial_ptr,
122+
UPSTREAM_ptr,
123+
batch_size,
124+
hidden_size: tl.constexpr,
125+
channels_per_group: tl.constexpr,
126+
num_groups: tl.constexpr,
127+
BLOCK_SIZE: tl.constexpr,
128+
dtype: tl.constexpr,
129+
MAX_CHUNK_SIZE: tl.constexpr = 32,
136130
):
137-
pid = tl.program_id(0)
138-
num_progs = tl.num_programs(0)
131+
prog_id = tl.program_id(0)
132+
num_programs = tl.num_programs(0)
133+
total_tasks = num_groups * batch_size
139134

140-
grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M)
141-
num_col_blocks = tl.cdiv(hidden_size, BLOCK_SIZE_N)
142-
hidden_size_per_channel = hidden_size // channels_per_group
143-
N_inv = 1.0 / hidden_size
144-
row_offsets = tl.arange(0, BLOCK_SIZE_M)
145-
col_offsets_base = tl.arange(0, BLOCK_SIZE_N)
135+
for task_id in tl.range(prog_id, total_tasks, num_programs):
136+
batch_idx = task_id // num_groups
137+
group_idx = task_id % num_groups
146138

147-
if COMPUTE_PARAM_GRAD:
148-
DW_scratch_base = DW_scratch_ptr + pid * DW_scratch_stride
149-
DB_scratch_base = DB_scratch_ptr + pid * DB_scratch_stride
139+
num_channels = num_groups * channels_per_group
140+
X_row_stride = num_channels * hidden_size
150141

151-
# Persistent-program loop over row tiles.
152-
for block_m in tl.range(pid, grid_m, num_progs):
153-
row_idx = block_m * BLOCK_SIZE_M + row_offsets
154-
row_mask = row_idx < n_rows
155-
batch_idx = row_idx // num_groups
156-
group_idx = row_idx % num_groups
142+
X_ptr_task = X_ptr + batch_idx * X_row_stride
143+
DX_ptr_task = DX_ptr + batch_idx * X_row_stride
144+
UPSTREAM_ptr_task = UPSTREAM_ptr + batch_idx * X_row_stride
157145

158-
mean = tl.load(
159-
Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride,
160-
mask=row_mask,
161-
other=0.0,
162-
).to(tl.float32)
163-
rstd = tl.load(
164-
RSTD_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride,
165-
mask=row_mask,
166-
other=0.0,
167-
).to(tl.float32)
168-
169-
sum_x_hat_wdy = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
170-
sum_wdy = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
171-
172-
# Pass 1: compute row-wise reduction terms (c1, c2).
173-
for cb in range(num_col_blocks):
174-
col_offsets = cb * BLOCK_SIZE_N + col_offsets_base
175-
col_mask = col_offsets < hidden_size
176-
mask = row_mask[:, None] & col_mask[None, :]
146+
mean = tl.load(Mean_ptr + batch_idx * num_groups + group_idx)
147+
rstd = tl.load(RSTD_ptr + batch_idx * num_groups + group_idx)
177148

178-
X_ptrs = (
179-
X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
180-
)
181-
DY_ptrs = (
182-
DY_ptr + batch_idx[:, None] * DY_row_stride + group_idx[:, None] * DY_col_stride + col_offsets[None, :]
183-
)
184-
X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32)
185-
DY_block = tl.load(DY_ptrs, mask=mask, other=0.0).to(tl.float32)
149+
c1 = 0.0
150+
c2 = 0.0
151+
block_range = tl.arange(0, BLOCK_SIZE)
186152

187-
if SINGLE_CHANNEL_TILE:
188-
local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel
189-
global_channel = group_idx * channels_per_group + local_channel
190-
W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None]
191-
else:
192-
local_channel = col_offsets // hidden_size_per_channel
193-
global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :]
194-
W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32)
153+
scratch_base = batch_idx * (num_groups * num_channels) + group_idx * num_channels
154+
group_ch_start = group_idx * channels_per_group
195155

196-
x_hat = (X_block - mean[:, None]) * rstd[:, None]
197-
wdy = W_block * DY_block
198-
sum_x_hat_wdy += tl.sum(tl.where(mask, x_hat * wdy, 0.0), axis=1)
199-
sum_wdy += tl.sum(tl.where(mask, wdy, 0.0), axis=1)
156+
neg_mean = -mean
157+
inv_N = 1.0 / (hidden_size * channels_per_group)
200158

201-
c1 = sum_x_hat_wdy * N_inv
202-
c2 = sum_wdy * N_inv
159+
CHUNK_SIZE = tl.minimum(channels_per_group, MAX_CHUNK_SIZE)
160+
num_chunks = (channels_per_group + CHUNK_SIZE - 1) // CHUNK_SIZE
203161

204-
# Pass 2: compute DX and optionally accumulate DW/DB.
205-
# COMPUTE_PARAM_GRAD=False is used to skip expensive atomics in cases
206-
# where host-side dense reduction is faster/more stable.
207-
for cb in range(num_col_blocks):
208-
col_offsets = cb * BLOCK_SIZE_N + col_offsets_base
209-
col_mask = col_offsets < hidden_size
210-
mask = row_mask[:, None] & col_mask[None, :]
162+
for chunk_idx in tl.range(0, num_chunks):
163+
chunk_start = chunk_idx * CHUNK_SIZE
164+
chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)
211165

212-
X_ptrs = (
213-
X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
214-
)
215-
DY_ptrs = (
216-
DY_ptr + batch_idx[:, None] * DY_row_stride + group_idx[:, None] * DY_col_stride + col_offsets[None, :]
217-
)
218-
X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32)
219-
DY_block = tl.load(DY_ptrs, mask=mask, other=0.0).to(tl.float32)
166+
for local_ch in tl.range(chunk_start, chunk_end):
167+
W = tl.load(W_ptr + group_ch_start + local_ch)
168+
channel_base = (group_ch_start + local_ch) * hidden_size
169+
dW = 0.0
170+
dB = 0.0
220171

221-
if SINGLE_CHANNEL_TILE:
222-
local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel
223-
global_channel = group_idx * channels_per_group + local_channel
224-
W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None]
225-
else:
226-
local_channel = col_offsets // hidden_size_per_channel
227-
global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :]
228-
W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32)
172+
for i in tl.range(0, hidden_size, BLOCK_SIZE, num_stages=8):
173+
offsets = i + block_range
174+
mask = offsets < hidden_size
175+
X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
176+
dy = tl.load(UPSTREAM_ptr_task + channel_base + offsets, mask=mask, other=0.0)
229177

230-
x_hat = (X_block - mean[:, None]) * rstd[:, None]
231-
wdy = W_block * DY_block
232-
DX_block = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd[:, None]
178+
x_hat = (X + neg_mean) * rstd
233179

234-
DX_ptrs = (
235-
DX_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
236-
)
237-
tl.store(DX_ptrs, DX_block.to(X_ptr.dtype.element_ty), mask=mask)
180+
dy_float = dy.to(tl.float32)
181+
dy_xh = dy_float * x_hat.to(tl.float32)
182+
183+
tile_xh_dy = tl.sum(dy_xh)
184+
tile_dy = tl.sum(dy_float)
185+
dW += tile_xh_dy
186+
dB += tile_dy
187+
188+
c1 += W * tile_xh_dy
189+
c2 += W * tile_dy
190+
191+
tl.store(DW_partial_ptr + scratch_base + group_ch_start + local_ch, dW)
192+
tl.store(DB_partial_ptr + scratch_base + group_ch_start + local_ch, dB)
193+
194+
c1 = c1 * inv_N
195+
c2 = c2 * inv_N
196+
c1_rstd2 = c1 * rstd * rstd
197+
c2_rstd = c2 * rstd
198+
bias = mean * c1_rstd2 - c2_rstd
238199

239-
if COMPUTE_PARAM_GRAD:
240-
if SINGLE_CHANNEL_TILE:
241-
dW_partial = tl.sum(tl.where(mask, DY_block * x_hat, 0.0), axis=1)
242-
dB_partial = tl.sum(tl.where(mask, DY_block, 0.0), axis=1)
243-
tl.atomic_add(DW_scratch_base + global_channel, dW_partial, mask=row_mask)
244-
tl.atomic_add(DB_scratch_base + global_channel, dB_partial, mask=row_mask)
200+
for chunk_idx in tl.range(0, num_chunks):
201+
chunk_start = chunk_idx * CHUNK_SIZE
202+
chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)
203+
204+
for local_ch in tl.range(chunk_start, chunk_end):
205+
W = tl.load(W_ptr + group_ch_start + local_ch)
206+
W_rstd = W * rstd
207+
channel_base = (group_ch_start + local_ch) * hidden_size
208+
209+
for i in tl.range(0, hidden_size, BLOCK_SIZE, num_stages=8):
210+
offsets = i + block_range
211+
mask = offsets < hidden_size
212+
X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
213+
dy = tl.load(UPSTREAM_ptr_task + channel_base + offsets, mask=mask, other=0.0)
214+
dx = W_rstd * dy - X * c1_rstd2 + bias
215+
tl.store(DX_ptr_task + channel_base + offsets, dx.to(dtype), mask=mask)
245216

246217

247218
# -----------------------------------------------------------------------------
@@ -341,88 +312,63 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
341312
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
342313
shape = dY.shape
343314
batch_size = shape[0]
315+
hidden_size = dY.shape[-1]
344316
channels_per_group = num_channels // num_groups
345-
X_grouped = X.view(batch_size, num_groups, -1)
346-
dY_grouped = dY.view(batch_size, num_groups, -1)
347-
hidden_size = dY_grouped.shape[-1]
348-
hidden_size_per_channel = hidden_size // channels_per_group
349-
n_rows = batch_size * num_groups
350317

351-
BLOCK_SIZE_N = min(128, triton.next_power_of_2(hidden_size))
352-
BLOCK_SIZE_M = get_optimal_block_size(
353-
n_rows,
354-
X.element_size(),
355-
BLOCK_SIZE_N,
356-
is_backward=True,
318+
dY = dY.view(batch_size, num_groups, -1)
319+
DX = torch.empty(
320+
(batch_size, num_groups, hidden_size * channels_per_group),
321+
dtype=X.dtype,
322+
device=X.device,
357323
)
358324

359-
# Same condition as forward:
360-
# if true, each BLOCK_SIZE_N tile maps cleanly to one channel segment.
361-
single_channel_tile = BLOCK_SIZE_N <= hidden_size_per_channel and hidden_size_per_channel % BLOCK_SIZE_N == 0
325+
_DW_partial = torch.zeros(
326+
(batch_size, num_groups, num_channels),
327+
dtype=torch.float32,
328+
device=W.device,
329+
)
330+
_DB_partial = torch.zeros(
331+
(batch_size, num_groups, num_channels),
332+
dtype=torch.float32,
333+
device=B.device,
334+
)
362335

363-
num_cores = get_npu_core_count()
364-
grid = min(num_cores, triton.cdiv(n_rows, BLOCK_SIZE_M))
365-
# For non-single-channel tiles, per-element atomic updates are costly.
366-
# In that case, kernel computes DX only and DW/DB are reduced on host side.
367-
compute_param_grad = single_channel_tile
368-
369-
DX = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
370-
if compute_param_grad:
371-
DW_scratch = torch.zeros((grid, num_channels), dtype=torch.float32, device=W.device)
372-
DB_scratch = torch.zeros((grid, num_channels), dtype=torch.float32, device=W.device)
373-
else:
374-
# Not used when COMPUTE_PARAM_GRAD=False.
375-
# Intentionally set to None to enforce fail-fast behavior if accidentally accessed.
376-
DW_scratch = None
377-
DB_scratch = None
378-
379-
_group_norm_backward_kernel[(grid,)](
380-
X_grouped,
381-
X_grouped.stride(0),
382-
X_grouped.stride(1),
336+
element_size = dY.element_size()
337+
vv_alignment = 32
338+
required_elem_alignment = vv_alignment // element_size
339+
340+
BLOCK_SIZE = 512
341+
BLOCK_SIZE = max(BLOCK_SIZE, required_elem_alignment)
342+
BLOCK_SIZE = min(BLOCK_SIZE, MAX_FUSED_SIZE)
343+
344+
triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
345+
346+
aicore_num = 48
347+
total_tasks = batch_size * num_groups
348+
grid_size = min(aicore_num, total_tasks)
349+
grid_size = max(grid_size, 1)
350+
grid = (grid_size,)
351+
352+
_group_norm_backward_kernel[grid](
353+
X,
383354
W,
384355
Mean,
385-
Mean.stride(0),
386-
Mean.stride(1),
387356
RSTD,
388357
DX,
389-
DW_scratch,
390-
0 if not compute_param_grad else DW_scratch.stride(0),
391-
DB_scratch,
392-
0 if not compute_param_grad else DB_scratch.stride(0),
393-
dY_grouped,
394-
dY_grouped.stride(0),
395-
dY_grouped.stride(1),
396-
n_rows,
358+
_DW_partial,
359+
_DB_partial,
360+
dY,
361+
batch_size,
397362
hidden_size,
398363
channels_per_group,
399364
num_groups,
400-
SINGLE_CHANNEL_TILE=single_channel_tile,
401-
COMPUTE_PARAM_GRAD=compute_param_grad,
402-
BLOCK_SIZE_M=BLOCK_SIZE_M,
403-
BLOCK_SIZE_N=BLOCK_SIZE_N,
365+
BLOCK_SIZE=BLOCK_SIZE,
366+
dtype=triton_dtype,
367+
MAX_CHUNK_SIZE=32,
404368
)
405369

406-
# Precision note:
407-
# - In-kernel atomic_add on floating-point values is order-dependent under parallel
408-
# scheduling (non-associative summation), which can introduce run-to-run numerical
409-
# differences in DW/DB for contention-heavy shapes.
410-
# - Host-side dense reduction provides a more stable accumulation pattern for these
411-
# difficult layouts.
412-
if compute_param_grad:
413-
DW = DW_scratch.sum(dim=0).to(W.dtype)
414-
DB = DB_scratch.sum(dim=0).to(W.dtype)
415-
else:
416-
# Fallback path to avoid severe atomic contention when SINGLE_CHANNEL_TILE=False.
417-
# Layout: [B, G, hidden_size] -> [B, G, C_per_G, hidden_per_channel]
418-
X4 = X_grouped.reshape(batch_size, num_groups, channels_per_group, hidden_size_per_channel).to(torch.float32)
419-
dY4 = dY_grouped.reshape(batch_size, num_groups, channels_per_group, hidden_size_per_channel).to(torch.float32)
420-
mean4 = Mean.reshape(batch_size, num_groups, 1, 1).to(torch.float32)
421-
rstd4 = RSTD.reshape(batch_size, num_groups, 1, 1).to(torch.float32)
422-
423-
x_hat4 = (X4 - mean4) * rstd4
424-
DW = (dY4 * x_hat4).sum(dim=(0, 3)).reshape(-1).to(W.dtype)
425-
DB = dY4.sum(dim=(0, 3)).reshape(-1).to(W.dtype)
370+
DW = _DW_partial.sum(dim=(0, 1)).to(W.dtype)
371+
DB = _DB_partial.sum(dim=(0, 1)).to(B.dtype)
426372

427373
return DX.view(*shape), DW, DB
428374

0 commit comments

Comments (0)