Skip to content

Commit cd6ec32

Browse files
Tcc0403 and lancerts authored
Fix code style (#736)
## Summary
<!--- This is a required section; please describe the main purpose of this proposed code change. --->

<!---
## Details
This is an optional section; is there anything specific that reviewers should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was tested. --->

<!--
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.
-->

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
1 parent b828275 commit cd6ec32

File tree

3 files changed

+34
-20
lines changed

3 files changed

+34
-20
lines changed

src/liger_kernel/ops/rms_norm.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def _rms_norm_backward_kernel(
193193

194194
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
195195

196+
196197
@triton.jit
197198
def _block_rms_norm_forward_kernel(
198199
Y_ptr,
@@ -225,8 +226,11 @@ def _block_rms_norm_forward_kernel(
225226
row_mask = row_idx < n_rows
226227
col_mask = col_offsets < n_cols
227228

228-
229-
X_row = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :] , other=0)
229+
X_row = tl.load(
230+
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
231+
mask=row_mask[:, None] & col_mask[None, :],
232+
other=0,
233+
)
230234
X_row_dtype = X_row.dtype
231235
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
232236

@@ -262,7 +266,12 @@ def _block_rms_norm_forward_kernel(
262266
if casting_mode == _CASTING_MODE_GEMMA:
263267
Y_row = Y_row.to(X_row_dtype)
264268

265-
tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_row, mask=row_mask[:, None] & col_mask[None, :])
269+
tl.store(
270+
Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
271+
Y_row,
272+
mask=row_mask[:, None] & col_mask[None, :],
273+
)
274+
266275

267276
@triton.jit
268277
def _block_rms_norm_backward_kernel(
@@ -306,8 +315,16 @@ def _block_rms_norm_backward_kernel(
306315
for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
307316
row_idx = start + tl.arange(0, BLOCK_ROW)
308317
row_mask = row_idx < n_rows
309-
dY_row = tl.load(dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0.0)
310-
X_row = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0.0)
318+
dY_row = tl.load(
319+
dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
320+
mask=row_mask[:, None] & col_mask[None, :],
321+
other=0.0,
322+
)
323+
X_row = tl.load(
324+
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
325+
mask=row_mask[:, None] & col_mask[None, :],
326+
other=0.0,
327+
)
311328

312329
# Get cached rms
313330
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
@@ -326,7 +343,9 @@ def _block_rms_norm_backward_kernel(
326343

327344
dX_row = rstd_row[:, None] * m
328345

329-
dX_row += (rstd_row[:, None]) * (-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row)
346+
dX_row += (rstd_row[:, None]) * (
347+
-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
348+
)
330349

331350
# calculate the gradient of W
332351
if casting_mode == _CASTING_MODE_LLAMA:
@@ -335,8 +354,11 @@ def _block_rms_norm_backward_kernel(
335354
# here X_row is already in fp32 (see previous if block)
336355
dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
337356

338-
tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], dX_row, mask=row_mask[:, None] & col_mask[None, :])
339-
357+
tl.store(
358+
dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
359+
dX_row,
360+
mask=row_mask[:, None] & col_mask[None, :],
361+
)
340362

341363
tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
342364

@@ -549,15 +571,6 @@ def backward(ctx, dY):
549571
"""
550572
X, W, RSTD = ctx.saved_tensors
551573
dX, dW = rms_norm_backward(
552-
dY,
553-
X,
554-
W,
555-
RSTD,
556-
ctx.offset,
557-
ctx.casting_mode,
558-
ctx.BLOCK_SIZE,
559-
ctx.num_warps,
560-
ctx.in_place,
561-
ctx.row_mode
574+
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
562575
)
563576
return dX, dW, None, None, None, None, None

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,8 @@ def apply_liger_kernel_to_gemma3_text(
776776

777777
from transformers.models.gemma3 import modeling_gemma3
778778
from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
779-
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM, Gemma3TextModel
779+
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
780+
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
780781

781782
from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
782783
from liger_kernel.transformers.model.gemma3 import causal_forward

src/liger_kernel/transformers/rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def forward(self, hidden_states):
3737
self.offset,
3838
self.casting_mode,
3939
self.in_place,
40-
self.row_mode
40+
self.row_mode,
4141
)
4242

4343
def extra_repr(self):

0 commit comments

Comments
 (0)