Commit c46810b

fix qwen3 infer bug. (#878)
1 parent f3d8e61 commit c46810b

3 files changed: +16 -15 lines

lightllm/models/llama/triton_kernel/rmsnorm.py

Lines changed: 3 additions & 3 deletions
@@ -57,9 +57,9 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
     num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
-    # print(BLOCK_SIZE, num_warps, "block_size, numwarps")
-    BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
-    num_warps = 8
+    num_warps = triton.next_power_of_2(num_warps)
+    if BLOCK_SIZE > 16384:
+        BLOCK_SIZE = 16384
     # enqueue kernel
     _rms_norm_fwd_fused[(M,)](
         x_arg,
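
The removed lines pinned BLOCK_SIZE to 16384 (128 * 2**7) and num_warps to 8 unconditionally, which looks like a leftover debug override next to the commented-out print; the patch keeps the data-dependent heuristic, rounds num_warps to a power of two, and only caps BLOCK_SIZE at 16384. A minimal, self-contained sketch of how the launch parameters now fall out for a given feature width, assuming the surrounding code follows the usual Triton layer-norm recipe (the names and the 64 KB bookkeeping below are illustrative, not copied from rmsnorm.py):

def next_power_of_2(n: int) -> int:
    # pure-Python stand-in for triton.next_power_of_2
    return 1 << max(n - 1, 0).bit_length()

def launch_params(feature_dim: int, element_size: int = 2):
    max_fused_size = 65536 // element_size            # 64 KB budget in elements (fp16 -> 32768)
    block_size = min(max_fused_size, next_power_of_2(feature_dim))
    if feature_dim > block_size:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    num_warps = min(max(block_size // 256, 1), 8)     # unchanged heuristic
    num_warps = next_power_of_2(num_warps)            # new: keep num_warps a power of two
    if block_size > 16384:                            # new: cap, instead of the old hard-coded 16384
        block_size = 16384
    return block_size, num_warps

print(launch_params(128))    # Qwen3 head_dim: (128, 1) instead of the old fixed (16384, 8)
print(launch_params(5120))   # a wide hidden state: (8192, 8)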

lightllm/models/qwen3/layer_infer/transformer_layer_infer.py

Lines changed: 7 additions & 6 deletions
@@ -35,19 +35,20 @@ def _get_qkv(
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
         rmsnorm_forward(
-            q.reshape(-1, self.head_dim_),
+            q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
-            out=q.reshape(-1, self.head_dim_),
+            out=q.view(-1, self.head_dim_),
         )
 
-        rmsnorm_forward(
-            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, self.head_dim_),
+        cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
+            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
-            out=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, self.head_dim_),
-        )
+        ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
+
         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),
             cache_kv[:, : self.tp_k_head_num_, :],
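
The substantive fix is on the K side. cache_kv[:, : self.tp_k_head_num_, :] is a strided (in general non-contiguous) view into the fused KV buffer, so .reshape(-1, head_dim) has to materialize a copy; the old code then passed that copy as out=, which means the normalized keys never made it back into the KV cache. The patch normalizes the flattened copy and assigns the result back into the slice. (q, by contrast, comes out of the matmul contiguous, so its reshape can simply become view.) A stand-alone demonstration of the copy semantics, with hypothetical shapes and torch.nn.functional.normalize as a stand-in for rmsnorm_forward:

import torch
import torch.nn.functional as F

tokens, k_heads, v_heads, head_dim = 4, 2, 2, 8
cache_kv = torch.randn(tokens, k_heads + v_heads, head_dim)

k_slice = cache_kv[:, :k_heads, :]        # strided view over the fused KV buffer
flat = k_slice.reshape(-1, head_dim)      # non-contiguous -> reshape materializes a copy
assert not k_slice.is_contiguous()
assert flat.data_ptr() != cache_kv.data_ptr()

# Old pattern: the "in-place" write lands in the copy; cache_kv is untouched.
before = cache_kv.clone()
F.normalize(flat, dim=-1, out=flat)
assert torch.equal(cache_kv, before)      # the KV cache never saw the normalized keys

# Patched pattern: normalize the flattened copy, then assign the result back.
cache_kv[:, :k_heads, :] = F.normalize(
    cache_kv[:, :k_heads, :].reshape(-1, head_dim), dim=-1
).view(-1, k_heads, head_dim)
assert not torch.equal(cache_kv, before)  # now the cache holds the normalized keys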

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 6 additions & 6 deletions
@@ -61,18 +61,18 @@ def _get_qkv(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         rmsnorm_forward(
-            q.reshape(-1, self.head_dim_),
+            q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
-            out=q.reshape(-1, self.head_dim_),
+            out=q.view(-1, self.head_dim_),
         )
 
-        rmsnorm_forward(
-            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, self.head_dim_),
+        cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
+            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
-            out=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, self.head_dim_),
-        )
+        ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
+
         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),
             cache_kv[:, : self.tp_k_head_num_, :],
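
The MoE layer receives exactly the same two changes as the dense layer above. The q-side swap from reshape to view is behavior-preserving on a contiguous q, but it is also stricter: view refuses to flatten a tensor it cannot alias, so a silent copy like the K-cache one would fail loudly instead of dropping the normalization. A minimal sketch with hypothetical shapes:

import torch

tokens, q_heads, head_dim = 4, 8, 16
q = torch.randn(tokens, q_heads * head_dim)   # contiguous, like a matmul output

q_flat = q.view(-1, head_dim)                 # zero-copy alias of q
q_flat.mul_(0)                                # stand-in for rmsnorm_forward(..., out=...)
assert torch.count_nonzero(q) == 0            # in-place work on the view really hit q

# view() would have surfaced the K-cache bug immediately:
kv = torch.randn(tokens, 4, head_dim)
try:
    kv[:, :2, :].view(-1, head_dim)           # non-contiguous slice cannot be aliased
except RuntimeError as err:
    print("view() fails loudly:", err)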
