MAX_FUSED_SIZE = 16384


@triton.jit
def _group_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (batch_size, num_groups, hidden_size)
    X_ptr,  # pointer to input, shape (batch_size, num_groups, hidden_size)
    Mean_ptr,  # pointer to per-(batch, group) mean, shape (batch_size, num_groups)
    RSTD_ptr,  # pointer to per-(batch, group) reciprocal std, shape (batch_size, num_groups)
    W_ptr,  # pointer to affine scale weights, shape (num_channels,)
    B_ptr,  # pointer to affine bias weights, shape (num_channels,)
    batch_size,  # number of batches
    num_groups,  # number of groups
    hidden_size: tl.constexpr,  # elements per (batch, group) row
    channels_per_group: tl.constexpr,  # channels contained in one group
    eps,  # numerical-stability epsilon added to the variance
    BLOCK_SIZE: tl.constexpr,  # tile width for the 1D loads/stores
    MAX_CHUNK_SIZE: tl.constexpr = 128,  # upper bound on channels walked per chunk
):
    """Group-norm forward: one program instance per (batch, group) pair.

    Pass 1 accumulates sum(x) and sum(x*x) over the group row to derive
    mean / rstd; pass 2 applies ``y = (x - mean) * rstd * w + b`` channel by
    channel so the scalar affine parameters are loaded once per channel.

    Reference: https://nn.labml.ai/normalization/group_norm/index.html
    """
    hidden_size_per_channel: tl.constexpr = hidden_size // channels_per_group

    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    # Guard against a launch grid larger than the problem size.
    if batch_idx >= batch_size or group_idx >= num_groups:
        return

    # Flat index of this (batch, group) row in Mean/RSTD.
    offset = batch_idx * num_groups + group_idx

    X_ptr_task = X_ptr + batch_idx * hidden_size * num_groups + group_idx * hidden_size
    Y_ptr_task = Y_ptr + batch_idx * hidden_size * num_groups + group_idx * hidden_size

    block_range = tl.arange(0, BLOCK_SIZE)

    sum_acc = 0.0
    sum_sq_acc = 0.0

    # Pass 1: accumulate sum and sum-of-squares over the whole group row.
    for i in tl.range(0, hidden_size, BLOCK_SIZE, num_stages=16):
        hidden_size_offsets = i + block_range
        mask = hidden_size_offsets < hidden_size
        X = tl.load(X_ptr_task + hidden_size_offsets, mask=mask, other=0.0)

        sum_acc += tl.sum(X)
        sum_sq_acc += tl.sum(X * X)

    inv_hidden_size = 1.0 / hidden_size
    mean = sum_acc * inv_hidden_size
    # E[x^2] - E[x]^2 can come out slightly negative in floating point; clamp
    # to zero before rsqrt so rstd never becomes NaN (the previous
    # implementation clamped via tl.maximum(var, 0.0) as well).
    variance = (sum_sq_acc - sum_acc * mean) * inv_hidden_size
    rstd = rsqrt(tl.maximum(variance, 0.0) + eps)

    group_ch_start = group_idx * channels_per_group

    # Fold the mean into the bias:  y = x * (w * rstd) + (b - w * mean * rstd),
    # leaving a single fused multiply-add in the inner loop.
    neg_mean_rstd = -mean * rstd

    CHUNK_SIZE = tl.minimum(channels_per_group, MAX_CHUNK_SIZE)
    num_chunks = (channels_per_group + CHUNK_SIZE - 1) // CHUNK_SIZE

    # Pass 2: normalize + affine transform, walking channels chunk by chunk.
    for chunk_idx in tl.range(0, num_chunks):
        chunk_start = chunk_idx * CHUNK_SIZE
        chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)

        for local_ch in tl.range(chunk_start, chunk_end):
            W = tl.load(W_ptr + group_ch_start + local_ch)
            B = tl.load(B_ptr + group_ch_start + local_ch)

            scale = W * rstd
            bias = B + W * neg_mean_rstd
            channel_base = local_ch * hidden_size_per_channel

            for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE, num_stages=16):
                offsets = i + block_range
                mask = offsets < hidden_size_per_channel
                X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                Y = X * scale + bias
                tl.store(Y_ptr_task + channel_base + offsets, Y, mask=mask)

    tl.store(Mean_ptr + offset, mean)
    tl.store(RSTD_ptr + offset, rstd)
@triton.jit
def _group_norm_backward_kernel(
    X_ptr,  # pointer to forward input, shape (batch_size, num_groups, hidden_size)
    W_ptr,  # pointer to affine scale weights, shape (num_channels,)
    Mean_ptr,  # saved per-(batch, group) mean, shape (batch_size, num_groups)
    RSTD_ptr,  # saved per-(batch, group) reciprocal std, shape (batch_size, num_groups)
    DX_ptr,  # output: input gradient, shape (batch_size, num_groups, hidden_size)
    DW_partial_ptr,  # scratch: dW partial sums, shape (batch_size, num_groups, num_channels)
    DB_partial_ptr,  # scratch: dB partial sums, shape (batch_size, num_groups, num_channels)
    UPSTREAM_ptr,  # upstream gradient dY, same layout as X
    batch_size,
    hidden_size: tl.constexpr,
    channels_per_group: tl.constexpr,
    num_groups: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,  # triton element type used when storing DX
    MAX_CHUNK_SIZE: tl.constexpr = 128,  # upper bound on channels walked per chunk
):
    """Group-norm backward: one program instance per (batch, group) pair.

    Pass 1 accumulates per-channel dW = sum(dy * x_hat) and dB = sum(dy)
    into per-program scratch rows (reduced on the host, so no atomics are
    needed) while also accumulating the group-wide terms
    c1 = sum(w * dy * x_hat) and c2 = sum(w * dy).  Pass 2 then emits
    dx = rstd * (w * dy - (x_hat * c1 + c2) / N), algebraically rearranged
    into ``dx = (w * rstd) * dy - x * c1_rstd2 + bias`` so the inner loop is
    a fused multiply-add with precomputed scalars.
    """
    hidden_size_per_channel: tl.constexpr = hidden_size // channels_per_group

    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    # Guard against a launch grid larger than the problem size.
    if batch_idx >= batch_size or group_idx >= num_groups:
        return

    num_channels = num_groups * channels_per_group

    X_ptr_task = X_ptr + batch_idx * hidden_size * num_groups + group_idx * hidden_size
    DX_ptr_task = DX_ptr + batch_idx * hidden_size * num_groups + group_idx * hidden_size
    UPSTREAM_ptr_task = UPSTREAM_ptr + batch_idx * hidden_size * num_groups + group_idx * hidden_size

    mean = tl.load(Mean_ptr + batch_idx * num_groups + group_idx)
    rstd = tl.load(RSTD_ptr + batch_idx * num_groups + group_idx)

    c1 = 0.0
    c2 = 0.0
    block_range = tl.arange(0, BLOCK_SIZE)

    # Base of this program's private row in the (batch, group, channel) scratch.
    scratch_base = batch_idx * (num_groups * num_channels) + group_idx * num_channels
    group_ch_start = group_idx * channels_per_group

    neg_mean = -mean
    inv_N = 1.0 / hidden_size

    CHUNK_SIZE = tl.minimum(channels_per_group, MAX_CHUNK_SIZE)
    num_chunks = (channels_per_group + CHUNK_SIZE - 1) // CHUNK_SIZE

    # Pass 1: per-channel dW/dB plus the group-wide reduction terms c1/c2.
    for chunk_idx in tl.range(0, num_chunks):
        chunk_start = chunk_idx * CHUNK_SIZE
        chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)

        for local_ch in tl.range(chunk_start, chunk_end):
            W = tl.load(W_ptr + group_ch_start + local_ch)
            channel_base = local_ch * hidden_size_per_channel
            dW = 0.0
            dB = 0.0

            for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE, num_stages=16):
                offsets = i + block_range
                mask = offsets < hidden_size_per_channel
                X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                dy = tl.load(UPSTREAM_ptr_task + channel_base + offsets, mask=mask, other=0.0)

                x_hat = (X + neg_mean) * rstd

                # Accumulate in fp32 for stability regardless of input dtype.
                dy_float = dy.to(tl.float32)
                dy_xh = dy_float * x_hat.to(tl.float32)

                tile_xh_dy = tl.sum(dy_xh)
                tile_dy = tl.sum(dy_float)
                dW += tile_xh_dy
                dB += tile_dy

                c1 += W * tile_xh_dy
                c2 += W * tile_dy

            tl.store(DW_partial_ptr + scratch_base + group_ch_start + local_ch, dW)
            tl.store(DB_partial_ptr + scratch_base + group_ch_start + local_ch, dB)

    # Precompute the scalar terms of the rearranged dx formula.
    c1 = c1 * inv_N
    c2 = c2 * inv_N
    c1_rstd2 = c1 * rstd * rstd
    c2_rstd = c2 * rstd
    bias = mean * c1_rstd2 - c2_rstd

    # Pass 2: dx = (w * rstd) * dy - x * c1_rstd2 + bias.
    for chunk_idx in tl.range(0, num_chunks):
        chunk_start = chunk_idx * CHUNK_SIZE
        chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)

        for local_ch in tl.range(chunk_start, chunk_end):
            W = tl.load(W_ptr + group_ch_start + local_ch)
            W_rstd = W * rstd
            channel_base = local_ch * hidden_size_per_channel

            for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE, num_stages=16):
                offsets = i + block_range
                mask = offsets < hidden_size_per_channel
                X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                dy = tl.load(UPSTREAM_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                dx = W_rstd * dy - X * c1_rstd2 + bias
                tl.store(DX_ptr_task + channel_base + offsets, dx.to(dtype), mask=mask)


# -----------------------------------------------------------------------------


def group_norm_forward(X, num_channels, num_groups, W, B, eps):
    """NPU-optimized group_norm forward.

    Args:
        X: input tensor whose leading dim is the batch; channel dim is
           regrouped into (batch, num_groups, hidden_size) for the kernel.
        num_channels: total channel count C.
        num_groups: number of normalization groups G (assumes C % G == 0).
        W, B: affine scale / bias tensors, shape (C,).
        eps: numerical-stability epsilon.

    Returns:
        (Y, X_view, Mean, RSTD): output reshaped to the original shape, the
        (contiguous) input view, and the saved per-(batch, group) statistics
        consumed by the backward pass.
    """
    shape = X.shape
    batch_size = shape[0]
    channels_per_group = num_channels // num_groups
    # Regroup so mean/std are computed per (batch, group) row.
    X = X.view(batch_size, num_groups, -1).contiguous()
    hidden_size = X.shape[-1]

    element_size = X.element_size()
    # 32-byte alignment — presumably the NPU vector-unit width; TODO(review):
    # confirm against the Ascend backend documentation.
    vv_alignment = 32
    required_elem_alignment = vv_alignment // element_size

    # Heuristic tile width: grows with the row length, capped below.
    if hidden_size >= 2048:
        BLOCK_SIZE = 1024
    elif hidden_size >= 1024:
        BLOCK_SIZE = 512
    elif hidden_size >= 512:
        BLOCK_SIZE = 256
    else:
        BLOCK_SIZE = 128

    BLOCK_SIZE = max(BLOCK_SIZE, required_elem_alignment)
    BLOCK_SIZE = min(BLOCK_SIZE, MAX_FUSED_SIZE)

    if hidden_size % required_elem_alignment != 0:
        print(
            f"⚠️ Tail axis {hidden_size} does not meet alignment requirements (must be divisible by {required_elem_alignment})"
        )

    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

    # One program per (batch, group).  The grid must NOT be clamped to the
    # core count: the kernel reads tl.program_id(1) directly with no
    # grid-stride loop, so any clamped-away group would never be computed.
    grid = (max(batch_size, 1), max(num_groups, 1))

    _group_norm_forward_kernel[grid](
        Y,
        X,
        Mean,
        RSTD,
        W,
        B,
        batch_size,
        num_groups,
        hidden_size,
        channels_per_group,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_CHUNK_SIZE=128,
    )

    return Y.view(*shape), X.view(*shape), Mean, RSTD
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
    """NPU-optimized group_norm backward.

    Args:
        dY: upstream gradient, same shape as the forward input.
        X: forward input (as saved by group_norm_forward).
        W, B: affine scale / bias tensors, shape (num_channels,).
        Mean, RSTD: saved statistics, shape (batch, num_groups).
        num_channels, num_groups: channel / group counts.

    Returns:
        (DX, DW, DB): gradients w.r.t. the input, scale and bias.
    """
    shape = dY.shape
    batch_size = shape[0]
    channels_per_group = num_channels // num_groups

    dY = dY.view(batch_size, num_groups, -1).contiguous()
    X = X.view(batch_size, num_groups, -1).contiguous()
    hidden_size = dY.shape[-1]

    DX = torch.empty(
        (batch_size, num_groups, hidden_size),
        dtype=X.dtype,
        device=X.device,
    )
    # Per-(batch, group) scratch rows for dW/dB; each program writes only its
    # own row, so no atomics are needed and the host reduces afterwards.
    _DW_partial = torch.zeros(
        (batch_size, num_groups, num_channels),
        dtype=torch.float32,
        device=W.device,
    )
    _DB_partial = torch.zeros(
        (batch_size, num_groups, num_channels),
        dtype=torch.float32,
        device=B.device,
    )

    element_size = dY.element_size()
    # 32-byte alignment — presumably the NPU vector-unit width; TODO(review):
    # confirm against the Ascend backend documentation.
    vv_alignment = 32
    required_elem_alignment = vv_alignment // element_size

    # Heuristic tile width: grows with the row length, capped below.
    if hidden_size >= 2048:
        BLOCK_SIZE = 1024
    elif hidden_size >= 1024:
        BLOCK_SIZE = 512
    elif hidden_size >= 512:
        BLOCK_SIZE = 256
    else:
        BLOCK_SIZE = 128

    BLOCK_SIZE = max(BLOCK_SIZE, required_elem_alignment)
    BLOCK_SIZE = min(BLOCK_SIZE, MAX_FUSED_SIZE)

    # Map the torch element type to the matching triton type for the DX store.
    # fp16 must map to tl.float16, NOT tl.bfloat16 — they are different
    # formats, and storing fp16 data as bf16 corrupts the gradient.
    if X.dtype == torch.float32:
        triton_dtype = tl.float32
    elif X.dtype == torch.float16:
        triton_dtype = tl.float16
    else:
        triton_dtype = tl.bfloat16

    # One program per (batch, group); do not clamp to the core count because
    # the kernel has no grid-stride loop and clamped-away groups would be
    # silently skipped.
    grid = (max(batch_size, 1), max(num_groups, 1))

    _group_norm_backward_kernel[grid](
        X,
        W,
        Mean,
        RSTD,
        DX,
        _DW_partial,
        _DB_partial,
        dY,
        batch_size,
        hidden_size,
        channels_per_group,
        num_groups,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
        # Smaller chunks for very wide channel counts to bound on-chip usage.
        MAX_CHUNK_SIZE=64 if num_channels >= 512 else 16,
    )

    DW = _DW_partial.sum(dim=(0, 1)).to(W.dtype)
    DB = _DB_partial.sum(dim=(0, 1)).to(B.dtype)
    return DX.view(*shape), DW, DB