@@ -54,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -75,15 +76,17 @@ def _rms_norm_forward_kernel(
 
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(tl.float32)
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -104,7 +107,10 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
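
For review, a minimal PyTorch sketch of what the forward kernel now computes with and without a weight (roughly the Llama casting mode; the function name and eps default are illustrative, not part of this change):

```python
import torch

def rms_norm_fwd_ref(X, W=None, eps=1e-6, offset=0.0):
    # rstd is computed in fp32 for stability, as in the kernel
    rstd = torch.rsqrt(X.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    Y = (X.float() * rstd).to(X.dtype)
    if W is not None:  # elementwise_affine path
        Y = Y * (offset + W)
    return Y  # with W=None this is plain x / rms(x)
```
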
@@ -132,6 +138,7 @@ def _rms_norm_backward_kernel(
     offset,
     rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -145,16 +152,18 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     dY_ptr += row_start * dY_row_stride
     dX_ptr += row_start * dX_row_stride
 
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start
 
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
     for _ in range(row_start, row_end):
         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
@@ -167,24 +176,34 @@ def _rms_norm_backward_kernel(
 
         # Different backward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
@@ -193,7 +212,8 @@ def _rms_norm_backward_kernel(
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride
 
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit
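
The `dX_row` lines implement the RMSNorm input gradient. Writing m = dY * (W + offset) when a weight is present and m = dY otherwise, differentiating y = x * rstd * (offset + w) gives dX = rstd * m - (rstd^3 / n) * <m, x> * x per row; dW accumulates dY * (x * rstd) over rows. A hedged PyTorch reference (names illustrative):

```python
import torch

def rms_norm_bwd_ref(dY, X, W=None, eps=1e-6, offset=0.0):
    n = X.shape[-1]
    rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    m = dY * (W + offset) if W is not None else dY
    # dX = rstd * m - rstd^3 / n * <m, x> * x, rowwise
    dX = rstd * m - rstd.pow(3) / n * (m * X).sum(dim=-1, keepdim=True) * X
    # dW only exists on the elementwise_affine path
    dW = (dY * X * rstd).flatten(0, -2).sum(dim=0) if W is not None else None
    return dX, dW
```
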
@@ -211,6 +231,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -234,15 +255,17 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(tl.float32)
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -263,7 +286,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -294,6 +320,7 @@ def _block_rms_norm_backward_kernel(
     n_cols,
     offset,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -308,10 +335,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
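
The block kernels are persistent: each of the NUM_SMS programs walks the rows in a grid-stride loop, BLOCK_ROW rows per step, which is why each program can keep a single private `dW_row` partial. A plain-Python illustration of the row coverage, with made-up sizes:

```python
# every row is visited by exactly one program
NUM_SMS, BLOCK_ROW, n_rows = 4, 2, 13
covered = []
for pid in range(NUM_SMS):  # one pass per persistent program
    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
        covered += [r for r in range(start, start + BLOCK_ROW) if r < n_rows]
assert sorted(covered) == list(range(n_rows))
```
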
@@ -334,35 +362,45 @@ def _block_rms_norm_backward_kernel(
 
         # Different backward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 
         dX_row += (rstd_row[:, None]) * (
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
-            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
             dX_row,
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
@@ -391,8 +429,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}
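
Since `elementwise_affine` is a `tl.constexpr`, the branches touching `W_ptr` are pruned at compile time, which is what makes it safe below to pass `W=None` and a dummy stride of `0`. A minimal standalone sketch of the pattern (demo kernel only, not from this PR; assumes a Triton version that specializes `None` arguments away):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _demo_kernel(a_ptr, b_ptr, out_ptr, n, USE_B: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(a_ptr + offs, mask=mask, other=0.0)
    if USE_B:  # dead branch when USE_B is False; b_ptr is never dereferenced
        x += tl.load(b_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x, mask=mask)

a = torch.randn(100, device="cuda")
out = torch.empty_like(a)
_demo_kernel[(1,)](a, None, out, a.numel(), USE_B=False, BLOCK=128)
assert torch.equal(out, a)
```
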
@@ -405,13 +449,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -425,14 +470,15 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -454,8 +500,13 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "npu":
         sm_count = get_npu_multi_processor_count()
 
-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This RMS norm doesn't support feature dim >= 64KB.")
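
`_dW` is a two-stage reduction scratch buffer: each program accumulates the weight gradient for its share of rows into its own fp32 row, and the host finishes with the column-wise `_dW.sum(dim=0)` further down. A plain-PyTorch illustration with made-up sizes:

```python
import torch

sm_count, n_rows, n_cols = 4, 32, 16
rows_per_program = n_rows // sm_count
per_row_grad = torch.randn(n_rows, n_cols)  # stand-in for dY * x_hat, per row

_dW = torch.zeros(sm_count, n_cols)  # one fp32 partial row per program
for pid in range(sm_count):
    lo = pid * rows_per_program
    _dW[pid] = per_row_grad[lo : lo + rows_per_program].sum(dim=0)

assert torch.allclose(_dW.sum(dim=0), per_row_grad.sum(dim=0), atol=1e-5)
```
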
@@ -482,16 +533,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -508,21 +560,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
         )
     dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 
@@ -563,7 +620,11 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
569630 @staticmethod
@@ -572,7 +633,11 @@ def backward(ctx, dY):
572633 """
573634 Y: (B, T, H) or (BxT, H)
574635 """
575- X , W , RSTD = ctx .saved_tensors
636+ if ctx .elementwise_affine :
637+ X , W , RSTD = ctx .saved_tensors
638+ else :
639+ X , RSTD = ctx .saved_tensors
640+ W = None
576641 dX , dW = rms_norm_backward (
577642 dY , X , W , RSTD , ctx .offset , ctx .casting_mode , ctx .BLOCK_SIZE , ctx .num_warps , ctx .in_place , ctx .row_mode
578643 )
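
A hedged end-to-end smoke test for the new no-affine path, assuming this class is exposed as `liger_kernel.ops.rms_norm.LigerRMSNormFunction` and a CUDA device is available; it checks dX against an eager RMSNorm with no weight:

```python
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

X = torch.randn(4, 16, 64, device="cuda", requires_grad=True)
Y = LigerRMSNormFunction.apply(X, None, 1e-6)  # W=None -> elementwise_affine off
Y.sum().backward()

X_ref = X.detach().clone().requires_grad_(True)
rstd = torch.rsqrt(X_ref.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
(X_ref * rstd).sum().backward()

assert torch.allclose(X.grad, X_ref.grad, rtol=1e-4, atol=1e-4)
```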