
Commit 4691c71

fix and clean the code.

1 parent: 2302aee

4 files changed (+71, -48 lines)

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def _bind_ffn(self):
         self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)

     def rmsnorm(self, input, weight, out):
-        return rmsnorm_forward(input, weight, self.eps_, out=input)
+        return rmsnorm_forward(input, weight, self.eps_, out=out)

     def _get_qkv(
         self,
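The one-line fix above matters because the old call ignored the caller-supplied buffer: it normalized into `input` and returned that, so anything relying on `out` holding the result (or on `input` staying intact) silently broke. A minimal sketch of the intended contract, using a plain-PyTorch RMSNorm as a stand-in for the repo's `rmsnorm_forward` kernel (the helper name and shapes here are illustrative, not the repo's API):

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float,
                      out: torch.Tensor) -> torch.Tensor:
    # Compute RMSNorm in fp32 for stability, then write the result into the caller's buffer.
    x_f32 = x.float()
    normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    out.copy_((normed * weight.float()).to(x.dtype))
    return out

x = torch.randn(4, 8)
w = torch.ones(8)
out = torch.empty_like(x)
y = rmsnorm_reference(x, w, 1e-6, out)
assert y.data_ptr() == out.data_ptr()  # the result lands in `out`; `x` is left untouched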

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 29 additions & 31 deletions
@@ -1,6 +1,6 @@
 import os
 import torch
-import torch.functional as F
+import torch.nn.functional as F
 import torch.distributed as dist
 import numpy as np
 import triton
@@ -11,6 +11,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
@@ -33,9 +34,26 @@ def __init__(self, layer_num, network_config, mode=[]):
         self.is_linear = (layer_num + 1) % network_config["full_attention_interval"] != 0
         if self.is_linear:
             self.linear_attn_infer = Qwen3NextGatedDeltaNetInfer(network_config, layer_num, self.tp_world_size_)
+            return

+    @override
+    def _bind_norm(self):
+        self._att_norm = partial(Qwen3MOETransformerLayerInfer._att_norm, self)
+        self._ffn_norm = partial(Qwen3MOETransformerLayerInfer._ffn_norm, self)
         return

+    def _ffn_with_shared_expert(
+        self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+    ) -> torch.Tensor:
+        input = input.view(-1, self.embed_dim_)
+        up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input)
+        ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
+        silu_and_mul_fwd(up_gate_out, ffn1_out)
+        ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out)
+        shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out
+        moe_out = self._ffn(input, infer_state, layer_weight)
+        return shared_expert_out + moe_out
+
     @override
     def rmsnorm(self, input, weight, out: torch.Tensor):
         # Zero-Centered RMSNorm TODO trion op
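For reference, the new `_ffn_with_shared_expert` follows the usual Qwen shared-expert pattern: a dense SwiGLU MLP whose output is scaled by a per-token sigmoid gate and then added to the routed-experts output. A minimal dense-PyTorch sketch of that combination, with plain `nn.Linear` layers standing in for the fused `mm` weights and `routed_moe_out` standing in for `self._ffn` (the class, layer names, and scalar gate width are assumptions, not the repo's API):

import torch
import torch.nn.functional as F
from torch import nn

class SharedExpertSketch(nn.Module):
    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.gate_up = nn.Linear(hidden, 2 * inter, bias=False)  # fused gate/up projection
        self.down = nn.Linear(inter, hidden, bias=False)
        self.expert_gate = nn.Linear(hidden, 1, bias=False)      # per-token scalar gate

    def forward(self, x: torch.Tensor, routed_moe_out: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        shared = self.down(F.silu(gate) * up)                 # SwiGLU shared expert
        shared = torch.sigmoid(self.expert_gate(x)) * shared  # gate the shared branch
        return shared + routed_moe_out                        # combine with routed experts

x = torch.randn(3, 16)
m = SharedExpertSketch(hidden=16, inter=32)
print(m(x, routed_moe_out=torch.zeros(3, 16)).shape)  # torch.Size([3, 16])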
@@ -50,11 +68,8 @@ def rmsnorm(self, input, weight, out: torch.Tensor):
     def _get_o(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
     ) -> torch.Tensor:
-        # TODO fuse it
-        input = input.view(-1, self.tp_o_head_num_, self.head_dim_)
         input = input * layer_weight._gate
         layer_weight._gate = None
-        input = input.reshape(-1, self.tp_o_head_num_ * self.head_dim_)
         o_tensor = layer_weight.o_proj.mm(input)
         return o_tensor
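The `_get_o` cleanup, together with the matching changes in context_forward/token_forward below, drops the per-head view/reshape round trip: the sigmoid gate now stays in the flat [tokens, num_heads * head_dim] layout, where an elementwise multiply gives the same result without any reshaping before `o_proj`. A small equivalence check with made-up shapes:

import torch

tokens, heads, head_dim = 4, 2, 8
attn_out = torch.randn(tokens, heads * head_dim)
gate = torch.sigmoid(torch.randn(tokens, heads * head_dim))

# Old path: reshape to per-head blocks, multiply, flatten again.
old = (attn_out.view(tokens, heads, head_dim)
       * gate.view(tokens, heads, head_dim)).reshape(tokens, heads * head_dim)

# New path: multiply directly in the flat layout.
new = attn_out * gate

assert torch.equal(old, new)  # identical values, two reshapes fewer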

@@ -78,15 +93,13 @@ def context_forward(
         if self.is_linear:
             o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=True, infer_cls=self)
         else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)).view(
-                -1, self.tp_o_head_num_, self.head_dim_
-            )
+            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
             o = self.context_attention_forward(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight)
         input1 = None
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
@@ -113,15 +126,13 @@ def token_forward(
         if self.is_linear:
             o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=False, infer_cls=self)
         else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)).view(
-                -1, self.tp_o_head_num_, self.head_dim_
-            )
+            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
             o = self.token_attention_forward(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight)
         input1 = None
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
@@ -206,20 +217,14 @@ def _linear_attn(
         assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)
         input = input.view(-1, infer_cls.embed_dim_)

-        # Get conv_states and ssm_states buffer
         conv_states, ssm_states = infer_state.mem_manager.get_mamba_state_buffer(self.layer_idx_)

-        # Project input to qkvzba
-        mixed_qkvzba = layer_weight.linear_in_proj.mm(
-            input
-        )  # tgt: [batch_size, (self.key_dim * 2 + self.value_dim * 2) + (self.num_v_heads * 2)]
+        mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
         q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba)
-        mixed_qkv = torch.cat([q, k, v], dim=-1)  # tgt: [batch_size, tp_qkv_dim]
+        mixed_qkv = torch.cat([q, k, v], dim=-1)

-        # Convolution: different paths for prefill and decode
         if is_prefill:
-            # Prefill: use causal_conv1d_fn for full sequence processing
-            mixed_qkv = mixed_qkv.transpose(0, 1)  # [tp_qkv_dim, seq_len]
+            mixed_qkv = mixed_qkv.transpose(0, 1)
             out_tensor = infer_cls.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device)
             causal_conv1d_fn(
                 mixed_qkv,
@@ -229,12 +234,10 @@
                 infer_state.b1_cu_q_seq_len,
                 out=out_tensor,
                 cache_indices=infer_state.b_req_idx,
-                activation=self.activation,  # add the activation parameter
+                activation=self.activation,
             )
-            mixed_qkv = out_tensor.transpose(0, 1)  # [seq_len, tp_qkv_dim]
+            mixed_qkv = out_tensor.transpose(0, 1)
         else:
-            # Decode: use causal_conv1d_update for single token update
-            # Need to transpose conv_states to match expected format: (..., dim, state_len)
             mixed_qkv = causal_conv1d_update(
                 mixed_qkv,
                 conv_states.transpose(1, 2),
@@ -253,12 +256,9 @@
         g = fused_gdn_gating(layer_weight.linear_A_log.weight, a, layer_weight.linear_dt_bias.weight)
         g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))

-        # Recurrent attention: different paths for prefill and decode
         if is_prefill:
-            # Prefill: use chunk_gated_delta_rule
-            # Get initial state and clear it for new requests (no prompt cache support yet)
             initial_state = ssm_states[infer_state.b_req_idx].contiguous()
-            initial_state[...] = 0  # Clear initial state for all requests
+            initial_state[...] = 0
             (core_attn_out, last_recurrent_state,) = chunk_gated_delta_rule(
                 q=query,
                 k=key,
@@ -274,7 +274,6 @@
             # Update SSM state with final state
             ssm_states[infer_state.b_req_idx, ...] = last_recurrent_state.to(ssm_states.dtype)
         else:
-            # Decode: use fused_recurrent_gated_delta_rule for single token
             batch_size = input.shape[0]
             cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device)
             (core_attn_out, last_recurrent_state,) = fused_recurrent_gated_delta_rule(
@@ -290,7 +289,6 @@
                 use_qk_l2norm_in_kernel=True,
             )

-        # Gated RMSNorm and output projection
         z_shape_og = z.shape
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])

lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py

Lines changed: 41 additions & 14 deletions
@@ -35,10 +35,8 @@ def _parse_config(self):

     @override
     def _init_weight(self):
-        if self.is_moe:
-            self._init_moe()
-        else:
-            self._init_ffn()
+        self._init_moe()
+        self._init_shared_expert_weight()

         self.att_norm_weight_ = NormWeight(
             self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name
@@ -80,14 +78,48 @@ def load_hf_weights(self, weights):
         self._split_q_with_gate(weights)
         super().load_hf_weights(weights)

+    def _init_shared_expert_weight(self):
+        prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert"
+        self.shared_expert_gate_up_proj = MultiROWMMWeight(
+            weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"],
+            data_type=self.data_type_,
+            quant_cfg=self.quant_cfg,
+            layer_num=self.layer_num_,
+            name="shared_expert_gate_up_proj",
+        )
+        self.shared_expert_down_proj = COLMMWeight(
+            weight_name=f"{prefix}.down_proj.weight",
+            data_type=self.data_type_,
+            quant_cfg=self.quant_cfg,
+            layer_num=self.layer_num_,
+            name="shared_expert_down_proj",
+        )
+        self.shared_expert_gate = ROWMMWeight(
+            weight_name=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight",
+            data_type=self.data_type_,
+            bias_name=None,
+            quant_cfg=self.quant_cfg,
+            layer_num=self.layer_num_,
+            name="shared_expert_gate",
+            tp_rank=0,
+            tp_world_size=1,
+        )
+
     def _split_q_with_gate(self, weights):
         if self.q_proj.weight_name in weights:
-            q_size = self.tp_q_head_num_ * self.head_dim * self.tp_world_size_
-            _q_proj, _gate_proj = torch.split(weights[self.q_proj.weight_name], [q_size, q_size], dim=0)
+            weight = weights[self.q_proj.weight_name]
+            num_heads = self.tp_q_head_num_ * self.tp_world_size_
+            weight = weight.view(num_heads * 2, self.head_dim, -1)
+            _q_proj = weight[0::2].reshape(-1, weight.shape[-1])
+            _gate_proj = weight[1::2].reshape(-1, weight.shape[-1])
             weights[self.q_proj.weight_name] = _q_proj
             weights[self.o_gate_proj.weight_name] = _gate_proj
             if self.q_proj.bias_name in weights:
-                _q_proj, _gate_proj = torch.split(weights[self.q_proj.bias_name], [q_size, q_size], dim=0)
+                bias = weights[self.q_proj.bias_name]
+                num_heads = self.tp_q_head_num_ * self.tp_world_size_
+                bias = bias.view(num_heads * 2, self.head_dim)
+                _q_proj = bias[0::2].reshape(-1)
+                _gate_proj = bias[1::2].reshape(-1)
                 weights[self.q_proj.bias_name] = _q_proj
                 weights[self.o_gate_proj.bias_name] = _gate_proj
@@ -117,7 +149,6 @@ def _init_linear_weight(self):
         self.linear_conv1d = ROWMMWeight(
             weight_name=f"{prefix}.conv1d.weight",
             data_type=self.data_type_,
-            bias_name=f"{prefix}.conv1d.bias",
             quant_cfg=self.quant_cfg,
             layer_num=self.layer_num_,
             name="conv1d_weight",
@@ -126,7 +157,6 @@ def _init_linear_weight(self):
         self.linear_in_proj = MultiROWMMWeight(
             weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"],
             data_type=self.data_type_,
-            bias_names=[None],
             quant_cfg=self.quant_cfg,
             layer_num=self.layer_num_,
             name="in_proj_weight",
@@ -142,20 +172,17 @@ def _init_linear_weight(self):

         self.linear_dt_bias = TpParameterWeight(
             weight_name=f"{prefix}.dt_bias",
-            data_type=self.data_type_,
+            data_type=torch.float32,
             split_n_embed=self.linear_num_v_heads // self.tp_world_size_,
-            bias_name=None,
         )

         self.linear_A_log = TpParameterWeight(
             weight_name=f"{prefix}.A_log",
-            data_type=self.data_type_,
+            data_type=torch.float32,
             split_n_embed=self.linear_num_v_heads // self.tp_world_size_,
-            bias_name=None,
         )

         self.linear_norm = NormWeight(
             weight_name=f"{prefix}.norm.weight",
             data_type=self.data_type_,
-            bias_name=None,
         )
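Storing `dt_bias` and `A_log` in float32 rather than the model dtype is consistent with how Mamba-style decay gating is usually computed: the values go through exp/softplus, which are sensitive to reduced precision. A hedged sketch of the kind of computation `fused_gdn_gating` performs (this is the generic gated-delta-net decay formula, not the repo's Triton kernel):

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log: torch.Tensor, a: torch.Tensor,
                         dt_bias: torch.Tensor) -> torch.Tensor:
    # Per-head log decay: g = -exp(A_log) * softplus(a + dt_bias), accumulated in fp32.
    dt = F.softplus(a.float() + dt_bias.float())
    return -torch.exp(A_log.float()) * dt

num_v_heads, tokens = 8, 4
A_log = torch.randn(num_v_heads)    # kept in float32, as in the diff above
dt_bias = torch.randn(num_v_heads)
a = torch.randn(tokens, num_v_heads, dtype=torch.bfloat16)
g = gdn_gating_reference(A_log, a, dt_bias)
print(g.shape, g.dtype)  # torch.Size([4, 8]) torch.float32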

lightllm/models/qwen3next/model.py

Lines changed: 0 additions & 2 deletions
@@ -3,8 +3,6 @@
 from typing_extensions import override
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen3_moe.model import Qwen3MOEModel
-from lightllm.models.qwen3next.layer_weights.gdn_layer_weight import Qwen3NextGatedDeltaNetWeight
-from lightllm.models.qwen3next.layer_infer.gdn_layer_infer import Qwen3NextGatedDeltaNetInfer
 from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight
 from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextTransformerLayerInfer
 from lightllm.utils.log_utils import init_logger
