|
| 1 | +import math |
1 | 2 | import operator |
2 | 3 |
|
3 | 4 | import torch |
@@ -85,68 +86,87 @@ def _layer_norm_forward_kernel( |
85 | 86 | @triton.jit |
86 | 87 | def _layer_norm_backward_kernel( |
87 | 88 | X_ptr, # pointer to input, shape (n_rows, n_cols) |
| 89 | + stride_x, # stride of each row in input |
88 | 90 | W_ptr, # pointer to weights, shape (n_cols,) |
89 | 91 | Mean_ptr, # pointer to mean, shape (n_rows,) |
| 92 | + stride_mean, # stride of each row in mean |
90 | 93 | RSTD_ptr, # pointer to rstd, shape (n_rows,) |
| 94 | + stride_rstd, # stride of each row in rstd |
91 | 95 | DX_ptr, # pointer to input grad, shape (n_rows, n_cols) |
| 96 | + stride_dx, # stride of each row in input grad |
92 | 97 | DW_ptr, # pointer to weights grad, shape (n_cols,) |
| 98 | + stride_dw, # stride of each row in weights grad |
93 | 99 | DB_ptr, # pointer to bias grad, shape (n_cols,) |
| 100 | + stride_db, # stride of each row in bias grad |
94 | 101 | DY_ptr, # pointer to output grad, shape (n_rows, n_cols) |
95 | | - stride_x, # stride of each row in input |
96 | | - stride_dx, # stride of each row in input grad |
97 | 102 | stride_dy, # stride of each row in output grad |
| 103 | + n_rows, |
98 | 104 | n_cols, |
| 105 | + rows_per_program: tl.constexpr, |
99 | 106 | BLOCK_SIZE: tl.constexpr, |
100 | | - dtype: tl.constexpr, |
101 | | - atomic_dtype: tl.constexpr, |
102 | 107 | ): |
103 | 108 | """ |
104 | 109 | References: |
105 | 110 | https://arxiv.org/abs/1607.06450 |
106 | 111 | https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md |
107 | 112 | """ |
108 | | - row_idx = tl.program_id(0).to(tl.int64) |
| 113 | + row_block_id = tl.program_id(0).to(tl.int64) |
| 114 | + row_start = row_block_id * rows_per_program |
| 115 | + row_end = min((row_block_id + 1) * rows_per_program, n_rows) |
109 | 116 | cols = tl.arange(0, BLOCK_SIZE) |
110 | 117 | mask = cols < n_cols |
111 | 118 |
|
| 119 | + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) |
| 120 | + db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) |
| 121 | + |
112 | 122 | # Pre-load weights once (same optimization as forward pass) |
113 | 123 | w = tl.load(W_ptr + cols, mask=mask, other=0.0) |
114 | 124 | w_f32 = w.to(tl.float32) |
115 | 125 |
|
116 | 126 | # Calculate pointers for this specific row |
117 | | - row_X_ptr = X_ptr + row_idx * stride_x |
118 | | - row_DX_ptr = DX_ptr + row_idx * stride_dx |
119 | | - row_DY_ptr = DY_ptr + row_idx * stride_dy |
120 | | - row_Mean_ptr = Mean_ptr + row_idx |
121 | | - row_RSTD_ptr = RSTD_ptr + row_idx |
122 | | - |
123 | | - # Load data for this row |
124 | | - x = tl.load(row_X_ptr + cols, mask=mask, other=0.0) |
125 | | - dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0) |
126 | | - mean = tl.load(row_Mean_ptr) |
127 | | - rstd = tl.load(row_RSTD_ptr) |
128 | | - |
129 | | - # Convert to fp32 for numerical stability |
130 | | - x_f32 = x.to(tl.float32) |
131 | | - dy_f32 = dy.to(tl.float32) |
132 | | - mean_f32 = mean.to(tl.float32) |
133 | | - rstd_f32 = rstd.to(tl.float32) |
134 | | - |
135 | | - # Compute backward pass for this row |
136 | | - x_hat = (x_f32 - mean_f32) * rstd_f32 |
137 | | - wdy = w_f32 * dy_f32 |
138 | | - c1 = tl.sum(x_hat * wdy, axis=0) / n_cols |
139 | | - c2 = tl.sum(wdy, axis=0) / n_cols |
140 | | - dx = (wdy - (x_hat * c1 + c2)) * rstd_f32 |
141 | | - |
142 | | - # Store input gradient |
143 | | - tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask) |
144 | | - |
145 | | - # Accumulate weight and bias gradients using atomic operations |
146 | | - dw = dy_f32 * x_hat |
147 | | - db = dy_f32 |
148 | | - tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask) |
149 | | - tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask) |
| 127 | + row_X_ptr = X_ptr + row_start * stride_x |
| 128 | + row_DX_ptr = DX_ptr + row_start * stride_dx |
| 129 | + row_DY_ptr = DY_ptr + row_start * stride_dy |
| 130 | + row_Mean_ptr = Mean_ptr + row_start |
| 131 | + row_RSTD_ptr = RSTD_ptr + row_start |
| 132 | + |
| 133 | + for _ in range(row_start, row_end): |
| 134 | + # Load data for this row |
| 135 | + x = tl.load(row_X_ptr + cols, mask=mask, other=0.0) |
| 136 | + dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0) |
| 137 | + mean = tl.load(row_Mean_ptr) |
| 138 | + rstd = tl.load(row_RSTD_ptr) |
| 139 | + |
| 140 | + # Convert to fp32 for numerical stability |
| 141 | + x_f32 = x.to(tl.float32) |
| 142 | + dy_f32 = dy.to(tl.float32) |
| 143 | + mean_f32 = mean.to(tl.float32) |
| 144 | + rstd_f32 = rstd.to(tl.float32) |
| 145 | + |
| 146 | + # Compute backward pass for this row |
| 147 | + x_hat = (x_f32 - mean_f32) * rstd_f32 |
| 148 | + wdy = w_f32 * dy_f32 |
| 149 | + c1 = tl.sum(x_hat * wdy, axis=0) / n_cols |
| 150 | + c2 = tl.sum(wdy, axis=0) / n_cols |
| 151 | + dx = (wdy - (x_hat * c1 + c2)) * rstd_f32 |
| 152 | + |
| 153 | + # Store input gradient |
| 154 | + tl.store(row_DX_ptr + cols, dx, mask=mask) |
| 155 | + |
| 156 | + # Accumulate weight and bias gradients for this thread block's assigned rows |
| 157 | + dw = dy_f32 * x_hat |
| 158 | + db = dy_f32 |
| 159 | + dW_row += dw |
| 160 | + db_row += db |
| 161 | + |
| 162 | + row_X_ptr += stride_x |
| 163 | + row_DX_ptr += stride_dx |
| 164 | + row_DY_ptr += stride_dy |
| 165 | + row_Mean_ptr += stride_mean |
| 166 | + row_RSTD_ptr += stride_rstd |
| 167 | + |
| 168 | + tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask) |
| 169 | + tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask) |
150 | 170 |
|
151 | 171 |
|
152 | 172 | def layer_norm_forward(X, W, B, eps): |
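For reference, the `c1` / `c2` / `dx` lines above implement the standard per-row LayerNorm input gradient from the references in the docstring; a sketch of the formula (not part of this diff), with `x_hat_i = (x_i - mean) * rstd` and `g_i = w_i * dy_i`:

```latex
% Per-row input gradient: c2 = mean(g), c1 = mean(g * x_hat),
%   dx_i = (g_i - (x_hat_i * c1 + c2)) * rstd
\[
\frac{\partial L}{\partial x_i}
  = \mathrm{rstd}\,\Bigl( g_i - \frac{1}{N}\sum_{j=1}^{N} g_j
      - \hat{x}_i \cdot \frac{1}{N}\sum_{j=1}^{N} g_j \hat{x}_j \Bigr),
\qquad
\frac{\partial L}{\partial W_i} = \sum_{\text{rows}} \partial y_i\, \hat{x}_i,
\qquad
\frac{\partial L}{\partial B_i} = \sum_{\text{rows}} \partial y_i .
\]
```

The `dW_row` / `db_row` accumulators hold each program's share of the two row sums; the final sum over programs happens on the host (see `layer_norm_backward` below).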
@@ -228,60 +248,59 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): |
228 | 248 | dY = dY.view(-1, dim) |
229 | 249 | n_rows, n_cols = dY.shape |
230 | 250 |
|
231 | | - # Allocate gradient tensors |
232 | | - DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) |
233 | | - # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation) |
234 | | - grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype |
235 | | - DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device) |
236 | | - DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device) |
| 251 | + sm_count = 1 |
| 252 | + if X.device.type == "cuda": |
| 253 | + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count |
| 254 | + elif X.device.type == "xpu": |
| 255 | + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count |
| 256 | + |
| 257 | + # Accumulate per-program weight/bias gradient partials in fp32 for numerical stability. |
| 258 | + _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) |
| 259 | + _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) |
237 | 260 |
|
238 | 261 | # Calculate optimal block size and warp configuration |
239 | 262 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) |
240 | 263 | if n_cols > BLOCK_SIZE: |
241 | 264 | raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.") |
| 265 | + rows_per_program = math.ceil(n_rows / sm_count) |
| 266 | + grid = (sm_count,) |
242 | 267 |
|
243 | | - # Determine dtype for triton operations |
244 | | - triton_dtype = ( |
245 | | - tl.float32 |
246 | | - if X.dtype == torch.float32 |
247 | | - else tl.bfloat16 |
248 | | - if X.dtype == torch.bfloat16 |
249 | | - else tl.float16 |
250 | | - if X.dtype == torch.float16 |
251 | | - else tl.float32 # fallback |
252 | | - ) |
253 | | - |
254 | | - # Use float32 for atomic operations if bfloat16 is not supported |
255 | | - atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype |
| 268 | + # Allocate gradient tensors |
| 269 | + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) |
256 | 270 |
|
257 | 271 | kernel_args = {"num_warps": num_warps} |
258 | 272 | # XPU-specific optimization |
259 | 273 | if X.device.type == "xpu": |
260 | 274 | kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4}) |
261 | 275 |
|
262 | | - # Launch kernel with one thread block per row for optimal performance |
| 276 | + # Launch one program per SM; each program processes up to rows_per_program rows |
263 | | - grid = (n_rows,) |
264 | 277 | _layer_norm_backward_kernel[grid]( |
265 | 278 | X, |
| 279 | + X.stride(0), |
266 | 280 | W, |
267 | 281 | Mean, |
| 282 | + Mean.stride(0), |
268 | 283 | RSTD, |
| 284 | + RSTD.stride(0), |
269 | 285 | DX, |
270 | | - DW, |
271 | | - DB, |
272 | | - dY, |
273 | | - X.stride(0), |
274 | 286 | DX.stride(0), |
| 287 | + _DW, |
| 288 | + _DW.stride(0), |
| 289 | + _DB, |
| 290 | + _DB.stride(0), |
| 291 | + dY, |
275 | 292 | dY.stride(0), |
| 293 | + n_rows, |
276 | 294 | n_cols, |
| 295 | + rows_per_program=rows_per_program, |
277 | 296 | BLOCK_SIZE=BLOCK_SIZE, |
278 | | - dtype=triton_dtype, |
279 | | - atomic_dtype=atomic_dtype, |
280 | 297 | **kernel_args, |
281 | 298 | ) |
282 | 299 |
|
283 | 300 | DX = DX.view(*shape) |
284 | | - return DX, DW.to(W.dtype), DB.to(W.dtype) |
| 301 | + DW = _DW.sum(dim=0).to(W.dtype) |
| 302 | + DB = _DB.sum(dim=0).to(B.dtype) |
| 303 | + return DX, DW, DB |
285 | 304 |
|
286 | 305 |
|
287 | 306 | class LigerLayerNormFunction(torch.autograd.Function): |
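The host-side change pairs with the kernel: instead of `tl.atomic_add` into shared `DW` / `DB` buffers (which forced an fp32 `atomic_dtype` workaround for bfloat16), each program now writes one partial row into `(sm_count, n_cols)` fp32 buffers that are reduced with `.sum(dim=0)`. A rough CPU-side sketch of that two-stage reduction, using a hypothetical helper name (`split_reduce_dw_db` is illustrative only, not part of the diff):

```python
import torch

def split_reduce_dw_db(dy, x_hat, num_programs):
    """Illustrative stand-in for the kernel's split reduction: each 'program'
    accumulates a partial dW/dB row over its assigned rows; the host sums them."""
    n_rows, n_cols = dy.shape
    rows_per_program = -(-n_rows // num_programs)  # same as math.ceil(n_rows / num_programs)
    dw_partial = torch.zeros(num_programs, n_cols, dtype=torch.float32)
    db_partial = torch.zeros(num_programs, n_cols, dtype=torch.float32)
    for pid in range(num_programs):  # one iteration per Triton program in the real kernel
        start = pid * rows_per_program
        end = min(start + rows_per_program, n_rows)
        dw_partial[pid] = (dy[start:end] * x_hat[start:end]).float().sum(dim=0)
        db_partial[pid] = dy[start:end].float().sum(dim=0)
    # Host-side reduction, matching `_DW.sum(dim=0)` / `_DB.sum(dim=0)` above.
    return dw_partial.sum(dim=0), db_partial.sum(dim=0)
```

This replaces atomics with deterministic fp32 partial sums, at the cost of one extra `(sm_count, n_cols)` reduction per backward call.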
|