|
13 | 13 | # ----------------------------------------------------------------------------- |
14 | 14 |
|
15 | 15 |
|
# Upper bound (in elements) for a single fused tile; the host wrappers clamp
# BLOCK_SIZE to this value before launching the kernels.
MAX_FUSED_SIZE = 16384
16 | 19 | @triton.jit |
17 | 20 | def _group_norm_forward_kernel( |
18 | 21 | Y_ptr, # pointer to output, shape (B, G, hidden_size) |
@@ -109,139 +112,107 @@ def _group_norm_forward_kernel( |
109 | 112 |
|
@triton.jit
def _group_norm_backward_kernel(
    X_ptr,  # forward input; dense layout, logically (batch, num_channels, hidden_size) -- assumed contiguous, see host wrapper
    W_ptr,  # affine scale weights, shape (num_channels,)
    Mean_ptr,  # saved per-group means, shape (batch, num_groups)
    RSTD_ptr,  # saved per-group reciprocal stds, shape (batch, num_groups)
    DX_ptr,  # output: gradient w.r.t. input, same layout as X
    DW_partial_ptr,  # scratch for partial dW sums, shape (batch, num_groups, num_channels)
    DB_partial_ptr,  # scratch for partial dB sums, shape (batch, num_groups, num_channels)
    UPSTREAM_ptr,  # upstream gradient dY, same layout as X
    batch_size,
    hidden_size: tl.constexpr,  # elements per channel
    channels_per_group: tl.constexpr,
    num_groups: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # tile width along the hidden dimension
    dtype: tl.constexpr,  # element dtype used when storing DX
    MAX_CHUNK_SIZE: tl.constexpr = 32,  # cap on channels handled per chunk iteration
):
    # Persistent-program scheme: each program strides over (batch, group) tasks.
    prog_id = tl.program_id(0)
    num_programs = tl.num_programs(0)
    total_tasks = num_groups * batch_size

    for task_id in tl.range(prog_id, total_tasks, num_programs):
        batch_idx = task_id // num_groups
        group_idx = task_id % num_groups

        num_channels = num_groups * channels_per_group
        X_row_stride = num_channels * hidden_size  # elements per batch row

        # Base pointers for this batch row; X, DX and dY share one layout.
        X_ptr_task = X_ptr + batch_idx * X_row_stride
        DX_ptr_task = DX_ptr + batch_idx * X_row_stride
        UPSTREAM_ptr_task = UPSTREAM_ptr + batch_idx * X_row_stride

        # Saved forward statistics for this (batch, group).
        mean = tl.load(Mean_ptr + batch_idx * num_groups + group_idx)
        rstd = tl.load(RSTD_ptr + batch_idx * num_groups + group_idx)

        # Group-wide reduction terms: c1 = sum(w * dy * x_hat), c2 = sum(w * dy).
        c1 = 0.0
        c2 = 0.0
        block_range = tl.arange(0, BLOCK_SIZE)

        # Offset of this (batch, group) row inside the partial-gradient scratch.
        scratch_base = batch_idx * (num_groups * num_channels) + group_idx * num_channels
        group_ch_start = group_idx * channels_per_group

        neg_mean = -mean
        # 1 / N where N is the number of elements in one group.
        inv_N = 1.0 / (hidden_size * channels_per_group)

        # Process channels in chunks of at most MAX_CHUNK_SIZE.
        CHUNK_SIZE = tl.minimum(channels_per_group, MAX_CHUNK_SIZE)
        num_chunks = (channels_per_group + CHUNK_SIZE - 1) // CHUNK_SIZE

        # Pass 1: per-channel dW/dB partial sums plus the group reductions c1, c2.
        for chunk_idx in tl.range(0, num_chunks):
            chunk_start = chunk_idx * CHUNK_SIZE
            chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)

            for local_ch in tl.range(chunk_start, chunk_end):
                W = tl.load(W_ptr + group_ch_start + local_ch)
                channel_base = (group_ch_start + local_ch) * hidden_size
                dW = 0.0  # sum(dy * x_hat) over this channel
                dB = 0.0  # sum(dy) over this channel

                for i in tl.range(0, hidden_size, BLOCK_SIZE, num_stages=8):
                    offsets = i + block_range
                    mask = offsets < hidden_size
                    X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                    dy = tl.load(UPSTREAM_ptr_task + channel_base + offsets, mask=mask, other=0.0)

                    # Normalized activation recomputed from saved statistics.
                    x_hat = (X + neg_mean) * rstd

                    # Accumulate reductions in float32 for precision.
                    dy_float = dy.to(tl.float32)
                    dy_xh = dy_float * x_hat.to(tl.float32)

                    tile_xh_dy = tl.sum(dy_xh)
                    tile_dy = tl.sum(dy_float)
                    dW += tile_xh_dy
                    dB += tile_dy

                    # Fold the same tile sums into the group reductions.
                    c1 += W * tile_xh_dy
                    c2 += W * tile_dy

                # Per-(batch, group) partials; host reduces them densely (no atomics).
                tl.store(DW_partial_ptr + scratch_base + group_ch_start + local_ch, dW)
                tl.store(DB_partial_ptr + scratch_base + group_ch_start + local_ch, dB)

        # dx = rstd * (w*dy - x_hat*c1/N - c2/N), rearranged so the inner loop is
        # a single multiply-add: dx = (w*rstd)*dy - X*c1*rstd^2 + bias.
        c1 = c1 * inv_N
        c2 = c2 * inv_N
        c1_rstd2 = c1 * rstd * rstd
        c2_rstd = c2 * rstd
        bias = mean * c1_rstd2 - c2_rstd

        # Pass 2: compute and store DX using the finished reductions.
        for chunk_idx in tl.range(0, num_chunks):
            chunk_start = chunk_idx * CHUNK_SIZE
            chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, channels_per_group)

            for local_ch in tl.range(chunk_start, chunk_end):
                W = tl.load(W_ptr + group_ch_start + local_ch)
                W_rstd = W * rstd
                channel_base = (group_ch_start + local_ch) * hidden_size

                for i in tl.range(0, hidden_size, BLOCK_SIZE, num_stages=8):
                    offsets = i + block_range
                    mask = offsets < hidden_size
                    X = tl.load(X_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                    dy = tl.load(UPSTREAM_ptr_task + channel_base + offsets, mask=mask, other=0.0)
                    dx = W_rstd * dy - X * c1_rstd2 + bias
                    tl.store(DX_ptr_task + channel_base + offsets, dx.to(dtype), mask=mask)
245 | 216 |
|
246 | 217 |
|
247 | 218 | # ----------------------------------------------------------------------------- |
@@ -341,88 +312,63 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): |
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
    """Compute GroupNorm gradients (DX, DW, DB) with the Triton backward kernel.

    Args:
        dY: upstream gradient, shape (batch, num_channels, *spatial).
        X: forward input, same shape as dY.
        W: affine scale, shape (num_channels,).
        B: affine shift, shape (num_channels,).
        Mean: saved per-group means, shape (batch, num_groups).
        RSTD: saved per-group reciprocal stds, shape (batch, num_groups).
        num_channels: total channel count C.
        num_groups: number of normalization groups G (C must be divisible by G).

    Returns:
        Tuple (DX, DW, DB): DX has dY's shape and dtype; DW/DB are reduced in
        float32 and cast to W's/B's dtype.
    """
    shape = dY.shape
    batch_size = shape[0]
    channels_per_group = num_channels // num_groups
    # Elements per channel, derived from the total element count so that
    # multi-dim spatial inputs (B, C, H, W, ...) are handled correctly.
    # `dY.shape[-1]` would only be right for already-flattened (B, C, H) input.
    hidden_size = dY.numel() // (batch_size * num_channels)

    # The kernel indexes X/dY/DX with dense strides computed from the shape,
    # so both tensors must be contiguous (no-op for already-dense tensors).
    X = X.contiguous()
    dY = dY.contiguous().view(batch_size, num_groups, -1)
    DX = torch.empty(
        (batch_size, num_groups, hidden_size * channels_per_group),
        dtype=X.dtype,
        device=X.device,
    )

    # One float32 partial-sum row per (batch, group); reduced densely below,
    # which avoids in-kernel atomics and keeps the accumulation deterministic.
    _DW_partial = torch.zeros(
        (batch_size, num_groups, num_channels),
        dtype=torch.float32,
        device=W.device,
    )
    _DB_partial = torch.zeros(
        (batch_size, num_groups, num_channels),
        dtype=torch.float32,
        device=B.device,
    )

    # Keep the tile width a multiple of the 32-byte vector alignment.
    element_size = dY.element_size()
    vv_alignment = 32
    required_elem_alignment = vv_alignment // element_size

    BLOCK_SIZE = 512
    BLOCK_SIZE = max(BLOCK_SIZE, required_elem_alignment)
    BLOCK_SIZE = min(BLOCK_SIZE, MAX_FUSED_SIZE)

    # Match the kernel's store dtype to the input dtype. Collapsing every
    # non-fp32 dtype to bfloat16 (as before) would store bf16 bit patterns
    # into a float16 DX tensor, silently corrupting fp16 results.
    if X.dtype == torch.float32:
        triton_dtype = tl.float32
    elif X.dtype == torch.float16:
        triton_dtype = tl.float16
    else:
        triton_dtype = tl.bfloat16

    # Persistent-kernel launch: at most one program per core, each program
    # loops over its share of the (batch, group) tasks.
    aicore_num = 48
    total_tasks = batch_size * num_groups
    grid_size = min(aicore_num, total_tasks)
    grid_size = max(grid_size, 1)
    grid = (grid_size,)

    _group_norm_backward_kernel[grid](
        X,
        W,
        Mean,
        RSTD,
        DX,
        _DW_partial,
        _DB_partial,
        dY,
        batch_size,
        hidden_size,
        channels_per_group,
        num_groups,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
        MAX_CHUNK_SIZE=32,
    )

    # Dense host-side reduction over all (batch, group) rows.
    DW = _DW_partial.sum(dim=(0, 1)).to(W.dtype)
    DB = _DB_partial.sum(dim=(0, 1)).to(B.dtype)

    return DX.view(*shape), DW, DB
428 | 374 |
|
|
0 commit comments