
Commit 268c18e

Author: sangchengmeng
Commit message: fix rms_norm
Parent: feb505b

File tree: 10 files changed (121 additions, 137 deletions)


lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py
Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@
 
 from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 
 
 class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer):

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
Lines changed: 6 additions & 6 deletions

@@ -154,16 +154,16 @@ def _get_qkv(
             q = layer_weight.q_weight_.mm(input)
         else:
             q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
             q = layer_weight.q_b_proj_.mm(q)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
-        rmsnorm_forward(
+        cache_kv[:, :, : self.kv_lora_rank] = rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
-            out=cache_kv[:, :, : self.kv_lora_rank],
+            use_custom_tensor_mananger=True
         )
 
         rotary_emb_fwd(
@@ -191,16 +191,16 @@ def _tpsp_get_qkv(
             q = layer_weight.q_weight_.mm(input)
         else:
             q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
-        rmsnorm_forward(
+        cache_kv[:, :, : self.kv_lora_rank] = rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
-            out=cache_kv[:, :, : self.kv_lora_rank],
+            use_custom_tensor_mananger=True
         )
         rotary_emb_fwd(
             q_rope,
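Both hunks above make the same switch: rmsnorm_forward no longer takes an out= argument, and the call sites consume the returned tensor instead. Below is a minimal plain-PyTorch sketch of the resulting pattern; the rms_norm_ref helper and the toy shapes are illustrative stand-ins, and only the formula is taken from torch_rms_norm in this commit.

# Illustrative sketch (plain PyTorch, not the Triton kernel): normalize a slice
# out-of-place, then copy the result back into the cache via slice assignment,
# instead of asking the kernel to write into `out=` in place.
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # same math as torch_rms_norm in lightllm/models/llama/triton_kernel/rmsnorm.py
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

kv_lora_rank, rope_dim = 4, 2  # toy sizes, not model config values
cache_kv = torch.randn(3, 1, kv_lora_rank + rope_dim)
weight = torch.ones(kv_lora_rank)

# New pattern: take the returned tensor and assign it back into the cache slice.
cache_kv[:, :, :kv_lora_rank] = rms_norm_ref(cache_kv[:, :, :kv_lora_rank], weight, eps=1e-6)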

lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
Lines changed: 4 additions & 4 deletions

@@ -20,8 +20,8 @@ def _mtp_context_forward(
     ):
         tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
         assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
-        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+        input_embdings = rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
+        tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
 
         cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
 
@@ -36,8 +36,8 @@ def _mtp_token_forward(
     ):
         tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
         assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
-        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+        input_embdings = rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
+        tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
 
         cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

lightllm/models/llama/layer_infer/post_layer_infer.py
Lines changed: 3 additions & 2 deletions

@@ -8,12 +8,13 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from einops import rearrange
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.common.basemodel import PostLayerInferTpl
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.distributed.communication_op import all_gather
 
 
+
 class LlamaPostLayerInfer(PostLayerInferTpl):
     """ """
 
@@ -25,7 +26,7 @@ def __init__(self, network_config, mode):
         return
 
     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-        return rms_norm(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
+        return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
 
     def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):

lightllm/models/llama/layer_infer/transformer_layer_infer.py
Lines changed: 4 additions & 8 deletions

@@ -1,3 +1,4 @@
+import os
 import torch
 import triton
 import torch.functional as F
@@ -14,7 +15,7 @@
 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k
 from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd
 from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
@@ -32,7 +33,6 @@
 
 from lightllm.utils.sgl_utils import flash_attn_with_kvcache
 
-
 class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
     """ """
 
@@ -134,16 +134,12 @@ def _bind_attention(self):
     def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
-        rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
-        return out
+        return rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
 
     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
-        rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
-        return out
+        return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
 
     def _get_qkv(
         self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight

lightllm/models/llama/triton_kernel/rmsnorm.py
Lines changed: 97 additions & 8 deletions

@@ -1,11 +1,11 @@
+import os
 import torch
-
 import triton
 import triton.language as tl
-
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 @triton.jit
-def _rms_norm_fwd_fused(
+def _rms_norm_low_accuracy_kernel(
     X,  # pointer to the input
     Y,  # pointer to the output
     W,  # pointer to the weights
@@ -41,9 +41,15 @@ def _rms_norm_fwd_fused(
     tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)
 
 
-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
+def rmsnorm_forward_low_accuracy(x: torch.Tensor, weight, eps, use_custom_tensor_mananger: bool = False):
     # allocate output
-    y = torch.empty_like(x) if out is None else out
+    if use_custom_tensor_mananger:
+        shape = x.shape
+        dtype = x.dtype
+        device = x.device
+        y = g_cache_manager.alloc_tensor(shape, dtype, device=device)
+    else:
+        y = torch.empty_like(x)
     # reshape input data into 2D tensor
     x_arg = x.view(-1, x.shape[-1])
     y_arg = y.view(-1, x.shape[-1])
@@ -61,7 +67,7 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
     # enqueue kernel
-    _rms_norm_fwd_fused[(M,)](
+    _rms_norm_low_accuracy_kernel[(M,)](
        x_arg,
        y_arg,
        weight,
@@ -77,6 +83,80 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     return y
 
 
+@triton.jit
+def _rms_norm_high_accuracy_kernel(
+    input,
+    weight,
+    output,
+    in_row_stride: tl.constexpr,
+    in_col_stride: tl.constexpr,
+    out_row_stride: tl.constexpr,
+    out_col_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
+
+    w = tl.load(weight + offsets, mask=offsets < N_COLS, other=0.0)
+
+    x_ptr = input + prog_id * in_row_stride
+    x = tl.load(x_ptr + offsets * in_col_stride, mask=offsets < N_COLS, other=0.0)
+    xf = x.to(tl.float32)
+
+    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+    out = xf / tl.sqrt(var + eps)
+    out = (w * out).to(x.dtype)
+
+    out_ptr = output + prog_id * out_row_stride
+    tl.store(out_ptr + offsets * out_col_stride, out, mask=offsets < N_COLS)
+
+
+def rmsnorm_forward_high_accuracy(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False):
+    """Rms norm."""
+
+    assert hidden_states.is_contiguous(), "hidden_states must be contiguous"
+
+    origin_shape = hidden_states.shape
+    hidden_dim = weight.shape[0]
+    assert hidden_dim == origin_shape[-1], f"hidden_dim {hidden_dim} != {origin_shape[-1]}"
+
+    rows = hidden_states.numel() // hidden_dim
+    if hidden_states.dim() == 3:  # (bs, seq_len, hidden_dim)
+        hidden_states = hidden_states.view(rows, hidden_dim)
+
+    in_row_stride, in_col_stride = hidden_states.stride(0), hidden_states.stride(1)
+
+    BLOCK_N = triton.next_power_of_2(hidden_dim)
+    if use_custom_tensor_mananger:
+        shape = hidden_states.shape
+        dtype = hidden_states.dtype
+        device = hidden_states.device
+        output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
+    else:
+        output = torch.empty_like(hidden_states)
+
+    out_row_stride, out_col_stride = output.stride(0), output.stride(1)
+    grid = (rows,)
+    _rms_norm_high_accuracy_kernel[grid](
+        hidden_states,
+        weight,
+        output,
+        in_row_stride,
+        in_col_stride,
+        out_row_stride,
+        out_col_stride,
+        eps=eps,
+        N_COLS=hidden_dim,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+        num_stages=3,
+    )
+    return output.reshape(origin_shape)
+
+
 def torch_rms_norm(x, weight, eps):
     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
 
@@ -88,11 +168,20 @@ def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
     weight = torch.rand(w_shape, dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
     # forward pass
-    y_tri = rmsnorm_forward(x, weight, eps)
+    y_tri = rmsnorm_forward_low_accuracy(x, weight, eps)
+    y_tri_high_acc = rmsnorm_forward_high_accuracy(x, weight, eps)
     y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)
 
     # compare
-    print("type:", y_tri.dtype, y_ref.dtype)
+    print("type:", y_tri.dtype, y_ref.dtype, y_tri_high_acc.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
+    print("max delta:", torch.max(torch.abs(y_tri_high_acc - y_ref)))
     assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
     return
+
+use_high_acc = os.getenv("RMSNORM_HIGH_ACCURACY", "False").upper() in ["ON", "TRUE", "1"]
+
+if use_high_acc:
+    rmsnorm_forward = rmsnorm_forward_high_accuracy
+else:
+    rmsnorm_forward = rmsnorm_forward_low_accuracy
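A short, hedged usage sketch of the names this file now exports; it assumes a CUDA device with Triton installed, and the shapes and dtypes are arbitrary examples rather than values from the commit. Because rmsnorm_forward is bound to one of the two implementations at import time, the RMSNORM_HIGH_ACCURACY environment variable (the code accepts ON, TRUE, or 1) has to be set before the module is imported.

# Compare both variants against the float32 reference defined in this file
# (torch_rms_norm). Requires CUDA + Triton; sizes and dtypes are illustrative.
import torch
from lightllm.models.llama.triton_kernel.rmsnorm import (
    rmsnorm_forward_low_accuracy,
    rmsnorm_forward_high_accuracy,
    torch_rms_norm,
)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
w = torch.rand(4096, dtype=torch.float16, device="cuda")
eps = 1e-5

y_ref = torch_rms_norm(x.float(), w.float(), eps).half()
y_low = rmsnorm_forward_low_accuracy(x, w, eps)
y_high = rmsnorm_forward_high_accuracy(x, w, eps)

print("low-accuracy  max |delta|:", (y_low - y_ref).abs().max().item())
print("high-accuracy max |delta|:", (y_high - y_ref).abs().max().item())

In the added kernel, the input is cast to float32 before the mean of squares and the square root are computed, which is what the "high accuracy" name refers to.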

lightllm/models/qwen3/layer_infer/transformer_layer_infer.py
Lines changed: 3 additions & 2 deletions

@@ -36,17 +36,18 @@ def _get_qkv(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
 
-        rmsnorm_forward(
+        q = rmsnorm_forward(
             q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
-            out=q.view(-1, self.head_dim_),
+            use_custom_tensor_mananger=True
         )
 
         cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
             cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
+            use_custom_tensor_mananger=True
         ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
 
         rotary_emb_fwd(

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
Lines changed: 3 additions & 2 deletions

@@ -60,17 +60,18 @@ def _get_qkv(
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
-        rmsnorm_forward(
+        q = rmsnorm_forward(
            q.view(-1, self.head_dim_),
            weight=layer_weight.q_norm_weight_.weight,
            eps=self.eps_,
-           out=q.view(-1, self.head_dim_),
+           use_custom_tensor_mananger=True
        )
 
        cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
            weight=layer_weight.k_norm_weight_.weight,
            eps=self.eps_,
+           use_custom_tensor_mananger=True
        ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
 
        rotary_emb_fwd(

lightllm/models/vit/layer_infer/transformer_layer_infer.py
Lines changed: 1 addition & 2 deletions

@@ -7,11 +7,10 @@
 import triton
 
 from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward_high_accuracy as rms_norm
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
