import os
import torch
-import torch.functional as F
+import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
import triton
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_global_world_size
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
@@ -33,9 +34,26 @@ def __init__(self, layer_num, network_config, mode=[]):
        self.is_linear = (layer_num + 1) % network_config["full_attention_interval"] != 0
        if self.is_linear:
            self.linear_attn_infer = Qwen3NextGatedDeltaNetInfer(network_config, layer_num, self.tp_world_size_)
+        return

+    @override
+    def _bind_norm(self):
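+        # Reuse the Qwen3-MoE layer's attention and FFN RMSNorm implementations for this layer.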
+        self._att_norm = partial(Qwen3MOETransformerLayerInfer._att_norm, self)
+        self._ffn_norm = partial(Qwen3MOETransformerLayerInfer._ffn_norm, self)
        return

+    def _ffn_with_shared_expert(
+        self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+    ) -> torch.Tensor:
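+        # Combine the always-active shared expert (a SiLU-gated MLP scaled by a sigmoid gate)
+        # with the routed MoE output.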
+        input = input.view(-1, self.embed_dim_)
+        up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input)
+        ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
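+        # Fused SiLU-and-multiply over the packed gate/up halves of up_gate_out, written into ffn1_out.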
+        silu_and_mul_fwd(up_gate_out, ffn1_out)
+        ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out)
+        shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out
+        moe_out = self._ffn(input, infer_state, layer_weight)
+        return shared_expert_out + moe_out
+
    @override
    def rmsnorm(self, input, weight, out: torch.Tensor):
        # Zero-Centered RMSNorm, TODO: triton op
@@ -50,11 +68,8 @@ def rmsnorm(self, input, weight, out: torch.Tensor):
    def _get_o(
        self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
    ) -> torch.Tensor:
-        # TODO fuse it
-        input = input.view(-1, self.tp_o_head_num_, self.head_dim_)
        input = input * layer_weight._gate
        layer_weight._gate = None
-        input = input.reshape(-1, self.tp_o_head_num_ * self.head_dim_)
        o_tensor = layer_weight.o_proj.mm(input)
        return o_tensor

@@ -78,15 +93,13 @@ def context_forward(
        if self.is_linear:
            o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=True, infer_cls=self)
        else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)).view(
-                -1, self.tp_o_head_num_, self.head_dim_
-            )
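+            # Cache the sigmoid output gate; _get_o applies it to the attention output and then clears it.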
+            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
            o = self.context_attention_forward(input1, infer_state, layer_weight)
        input_embdings.add_(o.view(-1, self.embed_dim_))
        o = None

        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight)
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
@@ -113,15 +126,13 @@ def token_forward(
        if self.is_linear:
            o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=False, infer_cls=self)
        else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)).view(
-                -1, self.tp_o_head_num_, self.head_dim_
-            )
+            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
            o = self.token_attention_forward(input1, infer_state, layer_weight)
        input_embdings.add_(o.view(-1, self.embed_dim_))
        o = None

        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight)
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
@@ -206,20 +217,14 @@ def _linear_attn(
        assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)
        input = input.view(-1, infer_cls.embed_dim_)

-        # Get conv_states and ssm_states buffer
        conv_states, ssm_states = infer_state.mem_manager.get_mamba_state_buffer(self.layer_idx_)

-        # Project input to qkvzba
-        mixed_qkvzba = layer_weight.linear_in_proj.mm(
-            input
-        )  # tgt: [batch_size, (self.key_dim * 2 + self.value_dim * 2) + (self.num_v_heads * 2)]
+        mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
        q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba)
-        mixed_qkv = torch.cat([q, k, v], dim=-1)  # tgt: [batch_size, tp_qkv_dim]
+        mixed_qkv = torch.cat([q, k, v], dim=-1)

-        # Convolution: different paths for prefill and decode
        if is_prefill:
-            # Prefill: use causal_conv1d_fn for full sequence processing
-            mixed_qkv = mixed_qkv.transpose(0, 1)  # [tp_qkv_dim, seq_len]
+            mixed_qkv = mixed_qkv.transpose(0, 1)
            out_tensor = infer_cls.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device)
            causal_conv1d_fn(
                mixed_qkv,
@@ -229,12 +234,10 @@ def _linear_attn(
                infer_state.b1_cu_q_seq_len,
                out=out_tensor,
                cache_indices=infer_state.b_req_idx,
-                activation=self.activation,  # add the activation parameter
+                activation=self.activation,
            )
-            mixed_qkv = out_tensor.transpose(0, 1)  # [seq_len, tp_qkv_dim]
+            mixed_qkv = out_tensor.transpose(0, 1)
        else:
-            # Decode: use causal_conv1d_update for single token update
-            # Need to transpose conv_states to match expected format: (..., dim, state_len)
            mixed_qkv = causal_conv1d_update(
                mixed_qkv,
                conv_states.transpose(1, 2),
@@ -253,12 +256,9 @@ def _linear_attn(
        g = fused_gdn_gating(layer_weight.linear_A_log.weight, a, layer_weight.linear_dt_bias.weight)
        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))

-        # Recurrent attention: different paths for prefill and decode
        if is_prefill:
-            # Prefill: use chunk_gated_delta_rule
-            # Get initial state and clear it for new requests (no prompt cache support yet)
            initial_state = ssm_states[infer_state.b_req_idx].contiguous()
-            initial_state[...] = 0  # Clear initial state for all requests
+            initial_state[...] = 0
            (core_attn_out, last_recurrent_state,) = chunk_gated_delta_rule(
                q=query,
                k=key,
@@ -274,7 +274,6 @@ def _linear_attn(
            # Update SSM state with final state
            ssm_states[infer_state.b_req_idx, ...] = last_recurrent_state.to(ssm_states.dtype)
        else:
-            # Decode: use fused_recurrent_gated_delta_rule for single token
            batch_size = input.shape[0]
            cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device)
            (core_attn_out, last_recurrent_state,) = fused_recurrent_gated_delta_rule(
@@ -290,7 +289,6 @@ def _linear_attn(
                use_qk_l2norm_in_kernel=True,
            )

-        # Gated RMSNorm and output projection
        z_shape_og = z.shape
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])