Commit 87d7057

[Feature] Add elementwise_affine argument to LigerRMSNorm (#989)
## Summary

This PR adds the `elementwise_affine` argument to `LigerRMSNorm` to align its API with `torch.nn.RMSNorm`. This allows users to instantiate the normalization layer without learnable parameters.

## Details

- Added `elementwise_affine` (defaulting to `True`) to the `__init__` method of `LigerRMSNorm`.
- When `elementwise_affine` is set to `False`, `self.weight` is registered as `None`, preventing the allocation of unnecessary parameters.
- Updated `extra_repr` to include the `elementwise_affine` status in the printed model structure.
- The parameter is added at the end of the argument list to maintain backward compatibility for subclasses (e.g., `LigerRMSNormForGemma`) that rely on positional arguments during `super().__init__` calls.

## Testing Done

- Hardware Type: NVIDIA A100-SXM4-80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
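For illustration, here is a minimal sketch of the module-side behavior described above. It is not the commit's actual code (the `LigerRMSNorm` diff lives in one of the other changed files, not shown below), and the class name, `hidden_size` parameter, and forward math are assumptions for the sketch:

```python
# Hypothetical sketch of the behavior described above; not the commit's actual code.
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            # Keep the attribute but allocate no learnable parameter,
            # mirroring torch.nn.RMSNorm(..., elementwise_affine=False).
            self.register_parameter("weight", None)

    def forward(self, x):
        # rstd and normalization in fp32 for stability, then cast back,
        # roughly like the "llama" casting mode of the kernels below.
        rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        y = (x.float() * rstd).to(x.dtype)
        return y * self.weight if self.weight is not None else y

    def extra_repr(self):
        return f"eps={self.eps}, elementwise_affine={self.elementwise_affine}"


# Usage: no learnable weight is created when elementwise_affine=False.
m = RMSNormSketch(4096, elementwise_affine=False)
print(m)                     # ... eps=1e-06, elementwise_affine=False
print(list(m.parameters()))  # []
```

Printing such a module with `elementwise_affine=False` surfaces the flag via `extra_repr`, and `parameters()` is empty, which is the behavior the PR description targets.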
1 parent 77949e0 commit 87d7057

File tree

3 files changed (+178 -63 lines changed)

src/liger_kernel/ops/rms_norm.py

Lines changed: 111 additions & 46 deletions
```diff
@@ -54,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -75,15 +76,17 @@ def _rms_norm_forward_kernel(

     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(tl.float32)

     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)

     if casting_mode == _CASTING_MODE_NONE:
@@ -104,7 +107,10 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)

-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row

     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -132,6 +138,7 @@ def _rms_norm_backward_kernel(
     offset,
     rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -145,16 +152,18 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols

-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

     dY_ptr += row_start * dY_row_stride
     dX_ptr += row_start * dX_row_stride

     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start

-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset

     for _ in range(row_start, row_end):
         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
@@ -167,24 +176,34 @@ def _rms_norm_backward_kernel(

         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)

         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row

         dX_row = rstd_row * m

         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)

         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

@@ -193,7 +212,8 @@ def _rms_norm_backward_kernel(
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride

-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


 @triton.jit
@@ -211,6 +231,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -234,15 +255,17 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(tl.float32)

     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)

     if casting_mode == _CASTING_MODE_NONE:
@@ -263,7 +286,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)

-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row

     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -294,6 +320,7 @@ def _block_rms_norm_backward_kernel(
     n_cols,
     offset,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -308,10 +335,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols

-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset

     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -334,35 +362,45 @@ def _block_rms_norm_backward_kernel(

         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)

         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row

         dX_row = rstd_row[:, None] * m

         dX_row += (rstd_row[:, None]) * (
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )

-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
-            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
             dX_row,
             mask=row_mask[:, None] & col_mask[None, :],
         )

-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)


 _str_to_casting_mode = {
@@ -391,8 +429,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False

     # XPU-specific optimization
     kernel_args = {}
@@ -405,13 +449,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -425,14 +470,15 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -454,8 +500,13 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "npu":
         sm_count = get_npu_multi_processor_count()

-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False

     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -482,16 +533,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -508,21 +560,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
         )
     dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None

     return dX, dW

@@ -563,7 +620,11 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y

     @staticmethod
@@ -572,7 +633,11 @@ def backward(ctx, dY):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
```
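For both the row-wise and block kernels above, the backward reduces to dX = rstd * m - (1 / n_cols) * rstd^3 * sum(m * X) * X, where m is dY * (offset + W) in the affine case and plain dY otherwise, and the dW accumulation and store are skipped entirely when there is no weight. Below is a hedged sanity-check sketch of the no-weight path; it assumes `LigerRMSNorm` is importable from `liger_kernel.transformers.rms_norm` and accepts `hidden_size`, `eps`, and the new `elementwise_affine` keyword, details that may differ from the exact API:

```python
# Hypothetical sanity check, not part of the commit: with elementwise_affine=False
# the Liger kernel should match a plain PyTorch RMSNorm reference with no weight.
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm  # assumed import path

hidden_size, eps = 64, 1e-6
x = torch.randn(2, 8, hidden_size, device="cuda", dtype=torch.float32, requires_grad=True)

norm = LigerRMSNorm(hidden_size, eps=eps, elementwise_affine=False).to("cuda")
assert norm.weight is None  # per the PR description, no parameter is allocated

y = norm(x)

# Reference: y = x * rsqrt(mean(x**2) + eps); no (offset + W) scaling is applied.
x_ref = x.detach().clone().requires_grad_(True)
y_ref = x_ref * torch.rsqrt(x_ref.pow(2).mean(-1, keepdim=True) + eps)

torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-4)

# Backward: dX should agree, and there is no dW since there is no weight.
y.sum().backward()
y_ref.sum().backward()
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
```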
