Commit 6655d1d

Author: sangchengmeng
Message: fix rms_norm
Parent: feb505b

File tree: 10 files changed, +138 / -135 lines


lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@

 from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward


 class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer):

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 6 deletions

@@ -154,16 +154,18 @@ def _get_qkv(
             q = layer_weight.q_weight_.mm(input)
         else:
             q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q = rmsnorm_forward(
+                q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
             q = layer_weight.q_b_proj_.mm(q)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
-        rmsnorm_forward(
+        cache_kv[:, :, : self.kv_lora_rank] = rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
-            out=cache_kv[:, :, : self.kv_lora_rank],
+            use_custom_tensor_mananger=True,
         )

         rotary_emb_fwd(
@@ -191,16 +193,16 @@ def _tpsp_get_qkv(
             q = layer_weight.q_weight_.mm(input)
         else:
             q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
-        rmsnorm_forward(
+        cache_kv[:, :, : self.kv_lora_rank] = rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
-            out=cache_kv[:, :, : self.kv_lora_rank],
+            use_custom_tensor_mananger=True,
         )
         rotary_emb_fwd(
             q_rope,
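
The calls above move from the old in-place `out=` argument to an out-of-place pattern: `rmsnorm_forward` returns its result (inside the engine the output buffer comes from the cache tensor manager via `use_custom_tensor_mananger=True`), and the caller rebinds it or copies it back into the slice it owns. A minimal standalone sketch of the new call pattern, with illustrative shapes rather than the model's real configuration:

    import torch
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

    # Illustrative sizes only (not DeepSeek's actual q LoRA rank).
    q = torch.randn(8, 1536, dtype=torch.bfloat16, device="cuda")
    norm_weight = torch.ones(1536, dtype=torch.bfloat16, device="cuda")

    # The normalized tensor is returned, not written into `q`; the caller must rebind it.
    # The default allocator (torch.empty_like) is used here; the engine passes
    # use_custom_tensor_mananger=True to allocate the output from g_cache_manager instead.
    q = rmsnorm_forward(q, weight=norm_weight, eps=1e-6)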

lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py

Lines changed: 12 additions & 4 deletions

@@ -20,8 +20,12 @@ def _mtp_context_forward(
     ):
         tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
         assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
-        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+        input_embdings = rmsnorm_forward(
+            input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
+        )
+        tgt_embdings = rmsnorm_forward(
+            tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
+        )

         cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

@@ -36,8 +40,12 @@ def _mtp_token_forward(
     ):
         tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
         assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
-        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+        input_embdings = rmsnorm_forward(
+            input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
+        )
+        tgt_embdings = rmsnorm_forward(
+            tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
+        )

         cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

lightllm/models/llama/layer_infer/post_layer_infer.py

Lines changed: 3 additions & 2 deletions

@@ -8,12 +8,13 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from einops import rearrange
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.common.basemodel import PostLayerInferTpl
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.distributed.communication_op import all_gather


+
 class LlamaPostLayerInfer(PostLayerInferTpl):
     """ """

@@ -25,7 +26,7 @@ def __init__(self, network_config, mode):
         return

     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-        return rms_norm(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
+        return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)

     def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 7 deletions

@@ -1,3 +1,4 @@
+import os
 import torch
 import triton
 import torch.functional as F
@@ -14,7 +15,7 @@
 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k
 from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd
 from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd

@@ -134,16 +135,16 @@ def _bind_attention(self):
     def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
-        rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
-        return out
+        return rmsnorm_forward(
+            input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+        )

     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
-        rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
-        return out
+        return rmsnorm_forward(
+            input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+        )

     def _get_qkv(
         self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight

lightllm/models/llama/triton_kernel/rmsnorm.py

Lines changed: 100 additions & 7 deletions

@@ -1,11 +1,12 @@
+import os
 import torch
-
 import triton
 import triton.language as tl
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 @triton.jit
-def _rms_norm_fwd_fused(
+def _rms_norm_low_accuracy_kernel(
     X,  # pointer to the input
     Y,  # pointer to the output
     W,  # pointer to the weights
@@ -41,9 +42,15 @@ def _rms_norm_fwd_fused(
     tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)


-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
+def rmsnorm_forward_low_accuracy(x: torch.Tensor, weight, eps, use_custom_tensor_mananger: bool = False):
     # allocate output
-    y = torch.empty_like(x) if out is None else out
+    if use_custom_tensor_mananger:
+        shape = x.shape
+        dtype = x.dtype
+        device = x.device
+        y = g_cache_manager.alloc_tensor(shape, dtype, device=device)
+    else:
+        y = torch.empty_like(x)
     # reshape input data into 2D tensor
     x_arg = x.view(-1, x.shape[-1])
     y_arg = y.view(-1, x.shape[-1])
@@ -61,7 +68,7 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
     # enqueue kernel
-    _rms_norm_fwd_fused[(M,)](
+    _rms_norm_low_accuracy_kernel[(M,)](
         x_arg,
         y_arg,
         weight,
@@ -77,6 +84,82 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     return y


+@triton.jit
+def _rms_norm_high_accuracy_kernel(
+    input,
+    weight,
+    output,
+    in_row_stride: tl.constexpr,
+    in_col_stride: tl.constexpr,
+    out_row_stride: tl.constexpr,
+    out_col_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
+
+    w = tl.load(weight + offsets, mask=offsets < N_COLS, other=0.0)
+
+    x_ptr = input + prog_id * in_row_stride
+    x = tl.load(x_ptr + offsets * in_col_stride, mask=offsets < N_COLS, other=0.0)
+    xf = x.to(tl.float32)
+
+    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+    out = xf / tl.sqrt(var + eps)
+    out = (w * out).to(x.dtype)
+
+    out_ptr = output + prog_id * out_row_stride
+    tl.store(out_ptr + offsets * out_col_stride, out, mask=offsets < N_COLS)
+
+
+def rmsnorm_forward_high_accuracy(
+    hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False
+):
+    """Rms norm."""
+
+    assert hidden_states.is_contiguous(), "hidden_states must be contiguous"
+
+    origin_shape = hidden_states.shape
+    hidden_dim = weight.shape[0]
+    assert hidden_dim == origin_shape[-1], f"hidden_dim {hidden_dim} != {origin_shape[-1]}"
+
+    rows = hidden_states.numel() // hidden_dim
+    if hidden_states.dim() == 3:  # (bs, seq_len, hidden_dim)
+        hidden_states = hidden_states.view(rows, hidden_dim)
+
+    in_row_stride, in_col_stride = hidden_states.stride(0), hidden_states.stride(1)
+
+    BLOCK_N = triton.next_power_of_2(hidden_dim)
+    if use_custom_tensor_mananger:
+        shape = hidden_states.shape
+        dtype = hidden_states.dtype
+        device = hidden_states.device
+        output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
+    else:
+        output = torch.empty_like(hidden_states)
+
+    out_row_stride, out_col_stride = output.stride(0), output.stride(1)
+    grid = (rows,)
+    _rms_norm_high_accuracy_kernel[grid](
+        hidden_states,
+        weight,
+        output,
+        in_row_stride,
+        in_col_stride,
+        out_row_stride,
+        out_col_stride,
+        eps=eps,
+        N_COLS=hidden_dim,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+        num_stages=3,
+    )
+    return output.reshape(origin_shape)
+
+
 def torch_rms_norm(x, weight, eps):
     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

@@ -88,11 +171,21 @@ def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
     weight = torch.rand(w_shape, dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
     # forward pass
-    y_tri = rmsnorm_forward(x, weight, eps)
+    y_tri = rmsnorm_forward_low_accuracy(x, weight, eps)
+    y_tri_high_acc = rmsnorm_forward_high_accuracy(x, weight, eps)
     y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)

     # compare
-    print("type:", y_tri.dtype, y_ref.dtype)
+    print("type:", y_tri.dtype, y_ref.dtype, y_tri_high_acc.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
+    print("max delta:", torch.max(torch.abs(y_tri_high_acc - y_ref)))
     assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
     return
+
+
+use_high_acc = os.getenv("RMSNORM_HIGH_ACCURACY", "False").upper() in ["ON", "TRUE", "1"]
+
+if use_high_acc:
+    rmsnorm_forward = rmsnorm_forward_high_accuracy
+else:
+    rmsnorm_forward = rmsnorm_forward_low_accuracy
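
With this change the public name `rmsnorm_forward` is bound once at import time: if `RMSNORM_HIGH_ACCURACY` is set to `ON`, `TRUE`, or `1`, callers get the new float32-accumulating kernel, otherwise the original low-accuracy path. A standalone sketch of exercising the switch (not part of the commit; illustrative shapes, and the variable must be set before the module is first imported):

    import os
    os.environ["RMSNORM_HIGH_ACCURACY"] = "1"  # must be set before the first import below

    import torch
    from lightllm.models.llama.triton_kernel.rmsnorm import (
        rmsnorm_forward,  # resolved to the high-accuracy variant by the env var
        torch_rms_norm,
    )

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    w = torch.rand(4096, dtype=torch.float16, device="cuda")
    eps = 1e-5

    y = rmsnorm_forward(x, w, eps)
    y_ref = torch_rms_norm(x.float(), w.float(), eps).to(x.dtype)
    print("max delta vs fp32 reference:", (y - y_ref).abs().max().item())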

lightllm/models/qwen3/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions

@@ -36,17 +36,18 @@ def _get_qkv(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

-        rmsnorm_forward(
+        q = rmsnorm_forward(
             q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
-            out=q.view(-1, self.head_dim_),
+            use_custom_tensor_mananger=True,
         )

         cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
             cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
+            use_custom_tensor_mananger=True,
         ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])

         rotary_emb_fwd(
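
The q/k normalization above applies RMSNorm per attention head: the tensor is flattened to rows of `head_dim_`, normalized against a `head_dim_`-sized weight, and viewed back into `(tokens, heads, head_dim)`. A sketch of the same pattern with made-up sizes (head_dim=128, 8 query heads), not the model's actual configuration:

    import torch
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

    tokens, n_heads, head_dim = 4, 8, 128
    q = torch.randn(tokens, n_heads * head_dim, dtype=torch.bfloat16, device="cuda")
    q_norm_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")

    # Flatten heads into the row dimension so every head is normalized independently,
    # then restore the (tokens, heads, head_dim) layout.
    q = rmsnorm_forward(q.view(-1, head_dim), weight=q_norm_weight, eps=1e-6)
    q = q.view(tokens, n_heads, head_dim)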

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions

@@ -60,17 +60,18 @@ def _get_qkv(
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
-        rmsnorm_forward(
+        q = rmsnorm_forward(
             q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
-            out=q.view(-1, self.head_dim_),
+            use_custom_tensor_mananger=True
         )

         cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
             cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
+            use_custom_tensor_mananger=True
         ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])

         rotary_emb_fwd(

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 2 deletions

@@ -7,11 +7,10 @@
 import triton

 from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward_high_accuracy as rms_norm
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

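
After this change the ViT layers bind `rms_norm` directly to the high-accuracy kernel, so this path always uses the float32-accumulating variant regardless of the `RMSNORM_HIGH_ACCURACY` switch that selects `rmsnorm_forward` elsewhere. A quick sketch with illustrative shapes (not the ViT model's actual sizes):

    import torch
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward_high_accuracy as rms_norm

    hidden = torch.randn(2, 1025, 1024, dtype=torch.float16, device="cuda")  # (bs, seq_len, hidden_dim)
    weight = torch.ones(1024, dtype=torch.float16, device="cuda")

    # The function handles the 3-D layout itself and returns a tensor in the original shape.
    out = rms_norm(hidden, weight, eps=1e-6)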
