@@ -193,6 +193,153 @@ def _rms_norm_backward_kernel(
193193
194194 tl .store (dW_ptr + row_block_id * dW_row_stride + col_offsets , dW_row , mask = mask )
195195
@triton.jit
def _block_rms_norm_forward_kernel(
    Y_ptr,  # output buffer, same logical shape as X (n_rows, n_cols)
    Y_row_stride,
    X_ptr,  # input buffer, flattened to (n_rows, n_cols)
    X_row_stride,
    W_ptr,  # weight vector of length n_cols
    W_row_stride,
    RSTD_ptr,  # per-row cache of 1/RMS, reused by the backward pass
    RSTD_row_stride,
    n_rows,
    n_cols,
    eps,  # numerical-stability term added to the mean of squares
    offset,  # added to the weight before scaling (e.g. Gemma-style (1 + w))
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    BLOCK_SIZE: tl.constexpr,  # power-of-two column block covering n_cols
    BLOCK_ROW: tl.constexpr,  # number of rows handled by one program
):
    """
    Block-wise (multi-row) RMSNorm forward:

        y_i = (x_i / RMS) * (offset + w_i),  RMS = sqrt(sum(x_i^2) / N)

    Unlike the one-row-per-program kernel, each program normalizes
    BLOCK_ROW rows at once as a 2D tile.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """

    # 2D tile of BLOCK_ROW rows x BLOCK_SIZE cols; masks guard the ragged edges.
    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_mask = row_idx < n_rows
    col_mask = col_offsets < n_cols

    # Masked-out positions load as 0, so they contribute nothing to the
    # sum of squares below.
    X_row = tl.load(
        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0,
    )
    X_row_dtype = X_row.dtype
    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    # No casting: keep every operand in the input dtype.
    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    # Row-wise mean of squares; divisor is the true n_cols, not BLOCK_SIZE.
    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
    rstd = rsqrt(mean_square + eps)

    # We can save time by caching rms with minimal memory overhead
    # because rms is much smaller compared to X_row, as rms is for each row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)

    X_row = X_row * rstd[:, None]

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    Y_row = X_row * (offset + W_row)[None, :]

    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(
        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
        Y_row,
        mask=row_mask[:, None] & col_mask[None, :],
    )
@triton.jit
def _block_rms_norm_backward_kernel(
    dY_ptr,  # upstream gradient, (n_rows, n_cols)
    dY_row_stride,
    dX_ptr,  # output: gradient w.r.t. X (may alias dY when run in-place)
    dX_row_stride,
    X_ptr,  # forward input, (n_rows, n_cols)
    X_row_stride,
    X_dtype: tl.constexpr,  # Triton dtype matching X's torch dtype
    W_ptr,  # weight vector of length n_cols
    W_row_stride,
    RSTD_ptr,  # per-row 1/RMS cached by the forward kernel
    RSTD_row_stride,
    dW_ptr,  # (num_programs, n_cols) partial weight-gradient buffer
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    rows_per_program: tl.constexpr,  # unused here; keeps the launch arguments aligned with _rms_norm_backward_kernel
    casting_mode: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # power-of-two column block covering n_cols
    BLOCK_ROW: tl.constexpr,  # rows processed per loop iteration
):
    """
    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x].
    * means element-wise multiplication, whereas dot means dot product.
    dw = sum(dy * (x / RMS)); summation over the BxT dimension.
    """

    # int64 pid so row/pointer offset arithmetic cannot overflow 32 bits
    # on large tensors.
    pid = tl.program_id(0).cast(tl.int64)
    NUM_SMS = tl.num_programs(0)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    # fp32 accumulator for this program's partial dW.
    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # (w + offset) is loop-invariant, so it is computed once up front.
    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
    W_row = W_row + offset

    # Grid-stride loop over blocks of BLOCK_ROW rows: each program handles
    # every NUM_SMS-th block.
    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
        row_idx = start + tl.arange(0, BLOCK_ROW)
        row_mask = row_idx < n_rows
        dY_row = tl.load(
            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )
        X_row = tl.load(
            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )

        # Get cached rms
        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
            # Llama multiplies dy by w in the original dtype, then upcasts.
            m = (dY_row * W_row[None, :]).to(tl.float32)
        elif casting_mode == _CASTING_MODE_GEMMA:
            # Gemma upcasts dy first so the whole product is in fp32.
            dY_row = dY_row.to(tl.float32)
            m = dY_row * W_row[None, :]
        else:
            m = dY_row * W_row[None, :]

        dX_row = rstd_row[:, None] * m

        dX_row += (rstd_row[:, None]) * (
            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
        )

        # calculate the gradient of W
        if casting_mode == _CASTING_MODE_LLAMA:
            # Llama accumulates dw against x/RMS cast back to X's dtype.
            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
        else:
            # here X_row is already in fp32 (see previous if block)
            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

        tl.store(
            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
            dX_row,
            mask=row_mask[:, None] & col_mask[None, :],
        )

    # One partial dW row per program; the host reduces with _dW.sum(dim=0).
    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
196343
197344_str_to_casting_mode = {
198345 "llama" : _CASTING_MODE_LLAMA .value ,
@@ -201,7 +348,7 @@ def _rms_norm_backward_kernel(
201348}
202349
203350
204- def rms_norm_forward (X , W , eps , offset , casting_mode ):
351+ def rms_norm_forward (X , W , eps , offset , casting_mode , row_mode ):
205352 if not isinstance (casting_mode , int ):
206353 assert casting_mode in _str_to_casting_mode , f"Invalid casting mode: { casting_mode } "
207354 casting_mode = _str_to_casting_mode [casting_mode ]
@@ -227,27 +374,49 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
227374 kernel_args = {}
228375 if X .device .type == "xpu" :
229376 kernel_args ["grf_mode" ] = "large"
230- _rms_norm_forward_kernel [(n_rows ,)](
231- Y ,
232- Y .stride (0 ),
233- X ,
234- X .stride (0 ),
235- W ,
236- W .stride (0 ),
237- RSTD ,
238- RSTD .stride (0 ),
239- n_cols ,
240- eps ,
241- offset ,
242- casting_mode ,
243- BLOCK_SIZE = BLOCK_SIZE ,
244- num_warps = num_warps ,
245- ** kernel_args , # XPU-specific optimization
246- )
377+ if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode :
378+ _rms_norm_forward_kernel [(n_rows ,)](
379+ Y ,
380+ Y .stride (0 ),
381+ X ,
382+ X .stride (0 ),
383+ W ,
384+ W .stride (0 ),
385+ RSTD ,
386+ RSTD .stride (0 ),
387+ n_cols ,
388+ eps ,
389+ offset ,
390+ casting_mode ,
391+ BLOCK_SIZE = BLOCK_SIZE ,
392+ num_warps = num_warps ,
393+ ** kernel_args , # XPU-specific optimization
394+ )
395+ else :
396+ BLOCK_ROW = 16
397+ kernel_args ["BLOCK_ROW" ] = BLOCK_ROW
398+ _block_rms_norm_forward_kernel [(triton .cdiv (n_rows , BLOCK_ROW ),)](
399+ Y ,
400+ Y .stride (0 ),
401+ X ,
402+ X .stride (0 ),
403+ W ,
404+ W .stride (0 ),
405+ RSTD ,
406+ RSTD .stride (0 ),
407+ n_rows ,
408+ n_cols ,
409+ eps ,
410+ offset ,
411+ casting_mode ,
412+ BLOCK_SIZE = BLOCK_SIZE ,
413+ num_warps = num_warps ,
414+ ** kernel_args , # XPU-specific optimization
415+ )
247416 return Y .view (* shape ), X , RSTD , BLOCK_SIZE , num_warps , casting_mode
248417
249418
250- def rms_norm_backward (dY , X , W , RSTD , offset , casting_mode , BLOCK_SIZE , num_warps , in_place ):
419+ def rms_norm_backward (dY , X , W , RSTD , offset , casting_mode , BLOCK_SIZE , num_warps , in_place , row_mode ):
251420 shape = dY .shape
252421 dim = shape [- 1 ]
253422 dY = dY .view (- 1 , dim )
@@ -277,29 +446,56 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
277446 if X .device .type == "xpu" :
278447 kernel_args ["grf_mode" ] = "large"
279448
280- _rms_norm_backward_kernel [grid ](
281- dY ,
282- dY .stride (0 ),
283- dX ,
284- dX .stride (0 ),
285- X ,
286- X .stride (0 ),
287- torch_to_triton_dtype [X .dtype ],
288- W ,
289- W .stride (0 ),
290- RSTD ,
291- RSTD .stride (0 ),
292- _dW ,
293- _dW .stride (0 ),
294- n_rows ,
295- n_cols ,
296- offset ,
297- rows_per_program ,
298- casting_mode ,
299- BLOCK_SIZE = BLOCK_SIZE ,
300- num_warps = num_warps ,
301- ** kernel_args , # XPU-specific optimization
302- )
449+ if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode :
450+ _rms_norm_backward_kernel [grid ](
451+ dY ,
452+ dY .stride (0 ),
453+ dX ,
454+ dX .stride (0 ),
455+ X ,
456+ X .stride (0 ),
457+ torch_to_triton_dtype [X .dtype ],
458+ W ,
459+ W .stride (0 ),
460+ RSTD ,
461+ RSTD .stride (0 ),
462+ _dW ,
463+ _dW .stride (0 ),
464+ n_rows ,
465+ n_cols ,
466+ offset ,
467+ rows_per_program ,
468+ casting_mode ,
469+ BLOCK_SIZE = BLOCK_SIZE ,
470+ num_warps = num_warps ,
471+ ** kernel_args , # XPU-specific optimization
472+ )
473+ else :
474+ BLOCK_ROW = 16
475+ kernel_args ["BLOCK_ROW" ] = BLOCK_ROW
476+ _block_rms_norm_backward_kernel [grid ](
477+ dY ,
478+ dY .stride (0 ),
479+ dX ,
480+ dX .stride (0 ),
481+ X ,
482+ X .stride (0 ),
483+ torch_to_triton_dtype [X .dtype ],
484+ W ,
485+ W .stride (0 ),
486+ RSTD ,
487+ RSTD .stride (0 ),
488+ _dW ,
489+ _dW .stride (0 ),
490+ n_rows ,
491+ n_cols ,
492+ offset ,
493+ rows_per_program ,
494+ casting_mode ,
495+ BLOCK_SIZE = BLOCK_SIZE ,
496+ num_warps = num_warps ,
497+ ** kernel_args , # XPU-specific optimization
498+ )
303499 dX = dX .view (* shape )
304500 dW = _dW .sum (dim = 0 ).to (W .dtype )
305501
@@ -330,15 +526,16 @@ class LigerRMSNormFunction(torch.autograd.Function):
330526
331527 @staticmethod
332528 @ensure_contiguous
333- def forward (ctx , X , W , eps , offset = 0.0 , casting_mode = "llama" , in_place = True ):
529+ def forward (ctx , X , W , eps , offset = 0.0 , casting_mode = "llama" , in_place = True , row_mode = None ):
334530 """
335531 X: (B, T, H) or (BxT, H)
336532 W: (H,)
337533 """
338- Y , X , RSTD , BLOCK_SIZE , num_warps , casting_mode = rms_norm_forward (X , W , eps , offset , casting_mode )
534+ Y , X , RSTD , BLOCK_SIZE , num_warps , casting_mode = rms_norm_forward (X , W , eps , offset , casting_mode , row_mode )
339535 ctx .offset = offset
340536 ctx .casting_mode = casting_mode
341537 ctx .in_place = in_place
538+ ctx .row_mode = row_mode
342539 ctx .BLOCK_SIZE = BLOCK_SIZE
343540 ctx .num_warps = num_warps
344541 ctx .save_for_backward (X , W , RSTD )
@@ -361,5 +558,6 @@ def backward(ctx, dY):
361558 ctx .BLOCK_SIZE ,
362559 ctx .num_warps ,
363560 ctx .in_place ,
561+ ctx .row_mode
364562 )
365- return dX , dW , None , None , None , None
563+ return dX , dW , None , None , None , None , None
0 commit comments