-
Notifications
You must be signed in to change notification settings - Fork 507
Expand file tree
/
Copy pathrms_norm.py
More file actions
732 lines (618 loc) · 24.5 KB
/
rms_norm.py
File metadata and controls
732 lines (618 loc) · 24.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.
The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
Modifications made by Yanning Chen, 2024.
"""
import math
import operator
import torch
import triton
import triton.language as tl
try:
from torch.distributed.tensor import Shard
_DTENSOR_AVAILABLE = True
except ImportError:
_DTENSOR_AVAILABLE = False
Shard = None
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode
from liger_kernel.ops.utils import torch_to_triton_dtype
from liger_kernel.utils import is_npu_available
def _is_hidden_dim_sharded(dtensor: "torch.distributed.tensor.DTensor") -> bool:
    """
    Return whether ``dtensor`` is sharded along its last (hidden) dimension.

    ``LigerRMSNormFunction`` uses the answer to pick a distributed strategy:
    - Sharded on the hidden dimension (Tensor Parallel): each rank holds only part of every
      row, so the full tensor must be gathered before the RMS can be computed.
    - Any other layout (e.g. Context Parallel, sharded on the sequence dimension): whole
      rows stay on each rank, so the norm can be computed locally.

    Args:
        dtensor: The DTensor whose placements are inspected.

    Returns:
        True if any placement is a ``Shard`` on dimension ``dtensor.ndim - 1``; False
        otherwise, including when DTensor support could not be imported.
    """
    if _DTENSOR_AVAILABLE and Shard is not None:
        last_dim = dtensor.ndim - 1
        return any(isinstance(p, Shard) and p.dim == last_dim for p in dtensor.placements)
    return False
# Select the `rsqrt` implementation for the installed Triton / backend combination:
# Triton >= 3.0 on non-NPU devices exposes it via libdevice; older Triton or NPU builds
# fall back to `triton.language.math`.
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt

# Integer codes for the supported casting strategies. They are `tl.constexpr` so that the
# kernels' `casting_mode == ...` branches can be optimized out at compile time.
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Row-wise RMSNorm forward kernel: each program normalizes exactly one row of X.

    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

    Args:
        Y_ptr, Y_row_stride: output buffer and its row stride.
        X_ptr, X_row_stride: input buffer and its row stride.
        W_ptr: weight vector; only read when `elementwise_affine` is true.
        W_row_stride: accepted for signature symmetry; not used by this kernel.
        RSTD_ptr, RSTD_row_stride: per-row buffer that receives 1 / RMS, which the backward
            kernels load instead of recomputing it.
        n_cols: hidden size N; lanes with index >= n_cols are masked out.
        eps: added to the mean square before the reciprocal square root.
        offset: constant added to the weight, so the output is scaled by (offset + W).
        casting_mode: one of _CASTING_MODE_{NONE,LLAMA,GEMMA}; decides where fp32 is used.
        elementwise_affine: if false, the output is just x / RMS (no weight).
        BLOCK_SIZE: compile-time tile width. Only the first BLOCK_SIZE columns are processed,
            so the launcher must choose BLOCK_SIZE >= n_cols.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """
    # The row index is made int64 so that `row_idx * *_row_stride` is computed in 64-bit.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    y_base = Y_ptr + row_idx * Y_row_stride
    x_base = X_ptr + row_idx * X_row_stride
    rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

    # Padding lanes load as 0, so they do not contribute to the sum of squares below.
    X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
    X_row_dtype = X_row.dtype
    if elementwise_affine:
        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        if elementwise_affine:
            W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    # With no casting, the scalars are brought to the input dtype so all math stays in it.
    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    # Divide by n_cols (the true hidden size), not BLOCK_SIZE, since padding lanes are zero.
    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
    rstd = rsqrt(mean_square + eps)

    # Cache rstd for the backward pass. It is one scalar per row, so the memory overhead is
    # small compared to X. In exchange, backward avoids recomputing it (*, sum, /, sqrt).
    tl.store(rstd_base, rstd)

    X_row = X_row * rstd

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    if elementwise_affine:
        Y_row = X_row * (offset + W_row)
    else:
        Y_row = X_row

    # Gemma did all of the math in fp32; cast the result back to the input dtype.
    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(y_base + col_offsets, Y_row, mask=mask)
@triton.jit
def _rms_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    rows_per_program,
    casting_mode: tl.constexpr,
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Row-wise RMSNorm backward kernel.

    Program `p` handles the contiguous row range
    [p * rows_per_program, min((p + 1) * rows_per_program, n_rows)), one row at a time.

    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x],
    where * means element-wise multiplication and dot means dot product.
    dw = sum(dy * (x / RMS)), summed over the BxT dimension.

    Args:
        dY_ptr, dY_row_stride: upstream gradient and its row stride.
        dX_ptr, dX_row_stride: output gradient w.r.t. X and its row stride.
        X_ptr, X_row_stride: forward input and its row stride.
        X_dtype: element type that dX is cast to before it is stored.
        W_ptr: weight vector; only read when `elementwise_affine` is true.
        W_row_stride: accepted for signature symmetry; not used by this kernel.
        RSTD_ptr, RSTD_row_stride: per-row 1 / RMS cached by the forward kernel.
        dW_ptr, dW_row_stride: scratch buffer with one row per program. Program `p` writes
            its fp32 partial dW (summed over its own rows) into row `p`.
        n_rows, n_cols: logical shape of the 2-D problem.
        offset: constant added to the weight in the forward pass.
        rows_per_program: number of consecutive rows assigned to each program.
        casting_mode: one of _CASTING_MODE_{NONE,LLAMA,GEMMA}; must match the forward pass.
        elementwise_affine: whether a weight was used (and dW must be produced).
        BLOCK_SIZE: compile-time tile width; must be >= n_cols.
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Partial dW for this program's rows, accumulated in fp32.
    if elementwise_affine:
        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Load (W + offset) once; it is reused for every row handled by this program.
    if elementwise_affine:
        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
        W_row = W_row + offset

    for row_idx in range(row_start, row_end):
        dy_base = dY_ptr + row_idx * dY_row_stride
        dx_base = dX_ptr + row_idx * dX_row_stride

        x_base = X_ptr + row_idx * X_row_stride
        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)

        # Get cached rms
        rstd_row = tl.load(rstd_base)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes.
        # m = dY * (W + offset) (or just dY without a weight).
        if casting_mode == _CASTING_MODE_LLAMA:
            # The weight product is formed in the original dtype (as in forward), then upcast.
            if elementwise_affine:
                m = (dY_row * W_row).to(tl.float32)
            else:
                m = dY_row.to(tl.float32)

        elif casting_mode == _CASTING_MODE_GEMMA:
            dY_row = dY_row.to(tl.float32)
            if elementwise_affine:
                m = dY_row * W_row
            else:
                m = dY_row
        else:
            if elementwise_affine:
                m = dY_row * W_row
            else:
                m = dY_row

        # dx = rstd * m - rstd * (1 / N) * rstd^2 * (m . x) * x
        dX_row = rstd_row * m

        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

        if elementwise_affine:
            # calculate the gradient of W
            if casting_mode == _CASTING_MODE_LLAMA:
                # Matches forward, where the normalized x was cast back to X's dtype
                # before being multiplied by the weight.
                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
            else:
                # here X_row is already in fp32 (see previous if block)
                dW_row += dY_row * (X_row * rstd_row)

        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)

    # One partial-dW row per program; the host launcher reduces these rows afterwards.
    if elementwise_affine:
        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
@triton.jit
def _block_rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_rows,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_ROW: tl.constexpr,
):
    """
    Multi-row RMSNorm forward kernel: each program normalizes a tile of BLOCK_ROW
    consecutive rows, i.e. rows [pid * BLOCK_ROW, (pid + 1) * BLOCK_ROW) clipped to n_rows.

    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

    Args:
        Y_ptr, Y_row_stride: output buffer and its row stride.
        X_ptr, X_row_stride: input buffer and its row stride.
        W_ptr: weight vector; only read when `elementwise_affine` is true.
        W_row_stride: accepted for signature symmetry; not used by this kernel.
        RSTD_ptr, RSTD_row_stride: per-row buffer receiving 1 / RMS for the backward pass.
        n_rows, n_cols: logical shape of the 2-D problem; out-of-range rows/cols are masked.
        eps: added to the mean square before the reciprocal square root.
        offset: constant added to the weight, so the output is scaled by (offset + W).
        casting_mode: one of _CASTING_MODE_{NONE,LLAMA,GEMMA}; decides where fp32 is used.
        elementwise_affine: if false, the output is just x / RMS (no weight).
        BLOCK_SIZE: compile-time tile width; must be >= n_cols.
        BLOCK_ROW: number of rows processed per program.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """
    # Promote the program id to int64 *before* scaling, so that `row_idx * *_row_stride`
    # is computed in 64-bit. Otherwise the offsets overflow int32 once
    # n_rows * row_stride >= 2**31. This mirrors `_rms_norm_forward_kernel` and
    # `_block_rms_norm_backward_kernel`.
    row_idx = tl.program_id(0).to(tl.int64) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_mask = row_idx < n_rows
    col_mask = col_offsets < n_cols

    # Masked lanes load as 0, so they do not contribute to the sum of squares.
    X_row = tl.load(
        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0,
    )
    X_row_dtype = X_row.dtype
    if elementwise_affine:
        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        if elementwise_affine:
            W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    # With no casting, the scalars are brought to the input dtype so all math stays in it.
    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    # Per-row mean square over the true hidden size (padding lanes are zero).
    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
    rstd = rsqrt(mean_square + eps)

    # Cache rstd for the backward pass: one scalar per row, so the memory overhead is small,
    # and backward avoids recomputing it (*, sum, /, sqrt).
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)

    X_row = X_row * rstd[:, None]

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    if elementwise_affine:
        Y_row = X_row * (offset + W_row)[None, :]
    else:
        Y_row = X_row

    # Gemma did all of the math in fp32; cast the result back to the input dtype.
    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(
        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
        Y_row,
        mask=row_mask[:, None] & col_mask[None, :],
    )
@triton.jit
def _block_rms_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    casting_mode: tl.constexpr,
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_ROW: tl.constexpr,
):
    """
    Multi-row RMSNorm backward kernel.

    Program `pid` processes tiles of BLOCK_ROW rows. It starts at row `pid * BLOCK_ROW` and
    advances by `num_programs * BLOCK_ROW` rows per iteration until n_rows is reached.

    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x],
    where * means element-wise multiplication and dot means dot product.
    dw = sum(dy * (x / RMS)), summed over the BxT dimension.

    Args:
        dY_ptr, dY_row_stride: upstream gradient and its row stride.
        dX_ptr, dX_row_stride: output gradient w.r.t. X and its row stride.
        X_ptr, X_row_stride: forward input and its row stride.
        X_dtype: element type of X; used for the Llama-mode dW rounding.
        W_ptr: weight vector; only read when `elementwise_affine` is true.
        W_row_stride: accepted for signature symmetry; not used by this kernel.
        RSTD_ptr, RSTD_row_stride: per-row 1 / RMS cached by the forward kernel.
        dW_ptr, dW_row_stride: scratch buffer with one row per program. Program `pid` writes
            its fp32 partial dW into row `pid`.
        n_rows, n_cols: logical shape of the 2-D problem.
        offset: constant added to the weight in the forward pass.
        casting_mode: one of _CASTING_MODE_{NONE,LLAMA,GEMMA}; must match the forward pass.
        elementwise_affine: whether a weight was used (and dW must be produced).
        BLOCK_SIZE: compile-time tile width; must be >= n_cols.
        BLOCK_ROW: number of rows per tile.
    """
    pid = tl.program_id(0).cast(tl.int64)
    NUM_SMS = tl.num_programs(0)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    if elementwise_affine:
        # Partial dW accumulated in fp32; (W + offset) is loaded once and reused per tile.
        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
        W_row = W_row + offset

    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
        row_idx = start + tl.arange(0, BLOCK_ROW)
        row_mask = row_idx < n_rows
        dY_row = tl.load(
            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )
        X_row = tl.load(
            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )

        # Get cached rms. Rows past n_rows are masked out, and masked lanes of tl.load are
        # undefined unless `other` is given. A defined 0.0 keeps those rows from injecting
        # NaN into the dW reduction below, where they appear as 0 * rstd.
        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes.
        # m = dY * (W + offset) (or just dY without a weight).
        if casting_mode == _CASTING_MODE_LLAMA:
            if elementwise_affine:
                m = (dY_row * W_row[None, :]).to(tl.float32)
            else:
                m = dY_row.to(tl.float32)

        elif casting_mode == _CASTING_MODE_GEMMA:
            dY_row = dY_row.to(tl.float32)
            if elementwise_affine:
                m = dY_row * W_row[None, :]
            else:
                m = dY_row
        else:
            if elementwise_affine:
                m = dY_row * W_row[None, :]
            else:
                m = dY_row

        # dx = rstd * m - rstd * (1 / N) * rstd^2 * (m . x) * x, computed per row of the tile.
        dX_row = rstd_row[:, None] * m

        dX_row += (rstd_row[:, None]) * (
            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
        )

        if elementwise_affine:
            if casting_mode == _CASTING_MODE_LLAMA:
                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
            else:
                # here X_row is already in fp32 (see previous if block)
                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

        tl.store(
            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
            dX_row,
            mask=row_mask[:, None] & col_mask[None, :],
        )

    # One partial-dW row per program; the host launcher reduces these rows afterwards.
    if elementwise_affine:
        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
# Maps the user-facing casting-mode names to the integer values of the `_CASTING_MODE_*`
# constexprs. `rms_norm_forward` accepts either a name or one of these integer codes.
_str_to_casting_mode = dict(
    llama=_CASTING_MODE_LLAMA.value,
    gemma=_CASTING_MODE_GEMMA.value,
    none=_CASTING_MODE_NONE.value,
)
def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
    """
    Launch the RMSNorm forward kernel on ``X``.

    ``X`` is viewed as a 2-D ``(rows, hidden)`` tensor, normalized row by row, and the
    output is viewed back to the original shape.

    Args:
        X: input tensor whose last dimension is the hidden size.
        W: weight vector whose length is the hidden size, or None for no elementwise affine.
        eps: epsilon added to the mean square.
        offset: constant added to ``W`` before scaling.
        casting_mode: "llama", "gemma", "none", or the matching integer code.
        row_mode: when truthy, always use the one-row-per-program kernel.

    Returns:
        ``(Y, X_2d, RSTD, BLOCK_SIZE, num_warps, casting_mode_code)``. ``RSTD`` holds 1 / RMS
        per row: fp32 for "llama"/"gemma", X's dtype for "none".
    """
    # Accept either a mode name or its integer code, and continue with the integer code.
    if isinstance(casting_mode, int):
        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
    else:
        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
        casting_mode = _str_to_casting_mode[casting_mode]

    original_shape = X.shape
    X = X.view(-1, original_shape[-1])
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)

    # Per-row 1 / RMS cache for the backward pass; fp32 whenever the kernel computes it in fp32.
    rstd_in_fp32 = casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
    RSTD = torch.empty(n_rows, dtype=torch.float32 if rstd_in_fp32 else X.dtype, device=X.device)

    elementwise_affine = W is not None
    if elementwise_affine:
        # Check constraints.
        assert X.shape[1] == W.shape[0], (
            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
        )
    weight_stride = W.stride(0) if elementwise_affine else 0

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        set_large_grf_mode(kernel_args)

    # Wide rows, few rows, or an explicit request use the one-row-per-program kernel;
    # otherwise each program handles a tile of BLOCK_ROW rows.
    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
        _rms_norm_forward_kernel[(n_rows,)](
            Y,
            Y.stride(0),
            X,
            X.stride(0),
            W,
            weight_stride,
            RSTD,
            RSTD.stride(0),
            n_cols,
            eps,
            offset,
            casting_mode,
            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
        )
    else:
        kernel_args["BLOCK_ROW"] = 16
        n_tiles = triton.cdiv(n_rows, kernel_args["BLOCK_ROW"])
        _block_rms_norm_forward_kernel[(n_tiles,)](
            Y,
            Y.stride(0),
            X,
            X.stride(0),
            W,
            weight_stride,
            RSTD,
            RSTD.stride(0),
            n_rows,
            n_cols,
            eps,
            offset,
            casting_mode,
            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
        )

    return Y.view(*original_shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
    """
    Launch the RMSNorm backward kernel and return ``(dX, dW)``.

    Args:
        dY: upstream gradient; viewed as ``(-1, hidden)`` for the kernel.
        X: the 2-D input returned/saved by ``rms_norm_forward``.
        W: weight vector, or None when the forward pass had no elementwise affine.
        RSTD: per-row 1 / RMS cached by the forward pass.
        offset: constant that was added to ``W`` in the forward pass.
        casting_mode: integer casting-mode code returned by ``rms_norm_forward``.
        BLOCK_SIZE, num_warps: launch settings reused from the forward pass.
        in_place: when exactly ``True``, dX is written into dY's storage to save memory.
        row_mode: when truthy, forces the one-row-per-iteration kernel.

    Returns:
        ``dX`` with dY's original shape, and ``dW`` cast to W's dtype (or None without W).

    Raises:
        RuntimeError: if the hidden size exceeds ``BLOCK_SIZE``.
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # Validate before querying the device or allocating any scratch buffers.
    if n_cols > BLOCK_SIZE:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    # Number of programs to launch. Each program writes one partial-dW row.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    elif X.device.type == "npu":
        sm_count = get_npu_core_count()

    if W is not None:
        # fp32 for numerical stability especially.
        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
        elementwise_affine = True
    else:
        _dW = None
        elementwise_affine = False

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    # Both kernels store every element of dX: the programs together cover all n_rows rows,
    # and each row store covers every column < n_cols. A zero-fill would therefore be an
    # extra full-size memset that is immediately overwritten.
    if in_place is True:
        dX = dY
    else:
        dX = torch.empty_like(dY)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        set_large_grf_mode(kernel_args)

    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
        _rms_norm_backward_kernel[grid](
            dY,
            dY.stride(0),
            dX,
            dX.stride(0),
            X,
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            W,
            W.stride(0) if elementwise_affine else 0,
            RSTD,
            RSTD.stride(0),
            _dW,
            _dW.stride(0) if elementwise_affine else 0,
            n_rows,
            n_cols,
            offset,
            rows_per_program,
            casting_mode,
            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
        )
    else:
        BLOCK_ROW = 16
        kernel_args["BLOCK_ROW"] = BLOCK_ROW
        _block_rms_norm_backward_kernel[grid](
            dY,
            dY.stride(0),
            dX,
            dX.stride(0),
            X,
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            W,
            W.stride(0) if elementwise_affine else 0,
            RSTD,
            RSTD.stride(0),
            _dW,
            _dW.stride(0) if elementwise_affine else 0,
            n_rows,
            n_cols,
            offset,
            casting_mode,
            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
        )
    dX = dX.view(*shape)

    # Reduce the per-program partial dW rows into the final gradient.
    if elementwise_affine:
        dW = _dW.sum(dim=0).to(W.dtype)
    else:
        dW = None

    return dX, dW
class LigerRMSNormFunction(torch.autograd.Function):
    """
    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
    weight tensor `W`, with an optional offset and casting mode.

    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
    `W` may also be None, in which case the output is simply `X / RMS(X)` and no weight gradient is returned.

    In addition, different models cast their inputs at different places during RMSNorm computation. For
    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
    support the following casting modes (they match HuggingFace Transformers' implementations):
    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.

    The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save
    memory. However, in certain cases this produces incorrect results.
    For example, gemma2 uses two rmsnorm sequentially with a residual in between. The residual part needs dY,
    so it cannot be modified in place.
    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`.

    DTensor inputs (when `torch.distributed.tensor` is available):
    - sharded on the hidden (last) dimension: gathered with `full_tensor()`, and a plain tensor is returned;
    - otherwise (e.g. sequence sharding): computed on the local shard, and the output is re-wrapped with the
      input's device mesh and placements. dW is then all-reduced over every mesh dimension with a `Shard`
      placement.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
        """
        X: (B, T, H) or (BxT, H)
        W: (H,) or None
        """
        # Track DTensor metadata for potential reconstruction in backward
        ctx.is_dtensor_input = False
        ctx.dtensor_device_mesh = None
        ctx.dtensor_placements = None

        # Check `_DTENSOR_AVAILABLE` first. When the DTensor import failed,
        # `torch.distributed.tensor` must not be dereferenced at all, or it raises AttributeError.
        if _DTENSOR_AVAILABLE and isinstance(X, torch.distributed.tensor.DTensor):
            if _is_hidden_dim_sharded(X):
                # Tensor Parallel (TP): hidden dimension is sharded across devices.
                # RMSNorm requires the full hidden dimension to compute the RMS,
                # so we need to gather the full tensor.
                X = X.full_tensor()
            else:
                # Context Parallel (CP): sequence dimension is sharded.
                # RMSNorm computes independently for each sequence position,
                # so we can compute locally without gathering.
                ctx.is_dtensor_input = True
                ctx.dtensor_device_mesh = X.device_mesh
                ctx.dtensor_placements = X.placements
                X = X.to_local()

        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
        ctx.offset = offset
        ctx.casting_mode = casting_mode
        ctx.in_place = in_place
        ctx.row_mode = row_mode
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.elementwise_affine = W is not None
        if W is not None:
            ctx.save_for_backward(X, W, RSTD)
        else:
            ctx.save_for_backward(X, RSTD)

        # If input was a CP DTensor, wrap output back into DTensor
        if ctx.is_dtensor_input:
            Y = torch.distributed.tensor.DTensor.from_local(
                Y,
                device_mesh=ctx.dtensor_device_mesh,
                placements=ctx.dtensor_placements,
            )
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """
        Y: (B, T, H) or (BxT, H)

        Returns gradients for (X, W, eps, offset, casting_mode, in_place, row_mode);
        only X and W receive non-None gradients.
        """
        if ctx.elementwise_affine:
            X, W, RSTD = ctx.saved_tensors
        else:
            X, RSTD = ctx.saved_tensors
            W = None

        # Same guard as in forward: skip the DTensor check when DTensor is unavailable.
        if _DTENSOR_AVAILABLE and isinstance(dY, torch.distributed.tensor.DTensor):
            if ctx.is_dtensor_input:
                # Context Parallel (CP): sequence dimension is sharded.
                # We can compute gradients locally for each sequence position.
                dY = dY.to_local()
            else:
                # Tensor Parallel (TP): hidden dimension is sharded.
                # Need to gather the full gradient tensor.
                dY = dY.full_tensor()

        dX, dW = rms_norm_backward(
            dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
        )

        # If input was a CP DTensor, handle output accordingly
        if ctx.is_dtensor_input:
            # Wrap dX back into DTensor with the same placements
            dX = torch.distributed.tensor.DTensor.from_local(
                dX,
                device_mesh=ctx.dtensor_device_mesh,
                placements=ctx.dtensor_placements,
            )

            # For dW, we need to all-reduce across all sharded mesh dimensions
            # since each device only computed gradients for its local sequence positions,
            # but the weight is shared across all positions. For multi-dimensional meshes
            # (e.g., batch + sequence sharding), we must reduce across each sharded dim.
            if dW is not None and _DTENSOR_AVAILABLE and Shard is not None:
                for i, placement in enumerate(ctx.dtensor_placements):
                    if isinstance(placement, Shard):
                        pg = ctx.dtensor_device_mesh.get_group(mesh_dim=i)
                        torch.distributed.all_reduce(dW, group=pg)

        return dX, dW, None, None, None, None, None