import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from functools import partial
from typing import Optional

from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer):
    """GPT-OSS transformer layer inference: MoE FFN with a clamped SwiGLU activation,
    alternating sliding-window / full attention, and attention sinks."""

    def __init__(self, layer_num, network_config, mode=[]):
        super().__init__(layer_num, network_config, mode)
        self.hidden_size = self.network_config_["hidden_size"]
        # Constants of the GPT-OSS clamped-SwiGLU activation used in the experts.
        self.alpha = 1.702
        self.limit = 7.0
        self.top_k = network_config["num_experts_per_tok"]
        self.sliding_window = network_config["sliding_window"]
        self.head_dim_ = network_config["head_dim"]

    def _bind_attention(self):
        self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
        self._context_attention_kernel = self._context_sliding_attention_flashattention
        self._token_attention_kernel = self._token_sliding_attention_flashattention

    def _bind_norm(self):
        # The norm methods are overridden directly on this class, so no rebinding of
        # the Llama norm kernels is needed here.
        self._att_norm = self._att_norm
        self._ffn_norm = self._ffn_norm
        return

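    # _experts evaluates every token against every expert (dense MoE) and relies on
    # the routing scores to zero out non-selected experts. A sketch of the per-expert
    # math as implemented below (E experts, T tokens; the intermediate size I is an
    # assumed name, not defined in this file):
    #   gate, up = split(x @ W_gate_up[e] + b_gate_up[e])    # interleaved columns, each (T, I)
    #   gate = clamp(gate, max=limit); up = clamp(up, -limit, limit)
    #   glu = gate * sigmoid(alpha * gate)                   # clamped SwiGLU
    #   out += scores[:, e, None] * (((up + 1) * glu) @ W_down[e] + b_down[e])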
    def _experts(
        self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight
    ):
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]

        # Replicate all tokens for every expert and run the expert MLPs as batched matmuls.
        hidden_states = hidden_states.repeat(num_experts, 1)
        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(hidden_states, layer_weight.gate_up_proj_weight) + layer_weight.gate_up_proj_bias.weight[..., None, :]
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        next_states = torch.bmm(((up + 1) * glu), layer_weight.down_proj_weight)
        next_states = next_states + layer_weight.down_proj_bias.weight[..., None, :]
        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
        # Weight each expert's output by its routing score and sum over experts.
        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        next_states = next_states.sum(dim=0)
        return next_states

    def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
        return self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)

    def _ffn_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
        return self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)

    def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        # Main difference from the Llama RMSNorm: the weight is applied in float32
        # before casting back to the input dtype.
        return (weight * hidden_states).to(input_dtype)

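    # _router produces a dense (seq_len, num_experts) score matrix: the top_k logits
    # are softmax-normalized and scattered back into a zero tensor, so non-selected
    # experts get a score of exactly zero and drop out of the weighted sum in _experts.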
    def _router(self, hidden_states, layer_weight: GptOssTransformerLayerWeight):
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        router_logits = layer_weight.moe_gate.mm(hidden_states)  # (seq_len, num_experts)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

    def _ffn(self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor:
        router_scores, router_indices = self._router(input, layer_weight)  # (seq_len, num_experts)
        routed_out = self._experts(input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight)
        return routed_out

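    # The two attention kernels below differ only in their query layout (prefill vs.
    # single-token decode). Layers marked "sliding_attention" in
    # network_config["layer_types"] attend to at most `sliding_window` previous tokens;
    # all other layers use full causal attention. The attention-sink weights
    # (layer_weight.attn_sinks) are forwarded to FlashAttention via the `sinks` argument.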
    def _context_sliding_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
            window_size = (self.sliding_window - 1, self.sliding_window - 1)
        else:
            window_size = (-1, -1)

        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
            -1, 1, self.tp_k_head_num_, self.head_dim_
        )
        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
        ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
        q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
        k_descale, v_descale = None, None  # disable quantization
        Lq = q.shape[-1]
        sm_scale = 1.0 / (Lq**0.5)
        o = flash_attn_with_kvcache(
            q=q,
            k_cache=cache_k,
            v_cache=cache_v,
            page_table=infer_state.page_table,
            cache_seqlens=infer_state.b_seq_len,
            cu_seqlens_q=infer_state.cu_seqlens_q,
            cu_seqlens_k_new=infer_state.cu_seqlens_k,
            max_seqlen_q=infer_state.q_max_seq_len,
            softmax_scale=sm_scale,
            causal=True,
            window_size=window_size,
            softcap=0.0,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=False,
            sinks=layer_weight.attn_sinks.weight,
        )
        return o

    def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
            window_size = (self.sliding_window - 1, self.sliding_window - 1)
        else:
            window_size = (-1, -1)

        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
            -1, 1, self.tp_k_head_num_, self.head_dim_
        )
        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
        ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
        q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
        k_descale, v_descale = None, None  # disable quantization
        Lq = q.shape[-1]
        sm_scale = 1.0 / (Lq**0.5)
        o = flash_attn_with_kvcache(
            q=q,
            k_cache=cache_k,
            v_cache=cache_v,
            page_table=infer_state.page_table,
            cache_seqlens=infer_state.b_seq_len,
            cu_seqlens_q=infer_state.cu_seqlens_q,
            cu_seqlens_k_new=infer_state.cu_seqlens_k,
            max_seqlen_q=1,
            softmax_scale=sm_scale,
            causal=True,
            window_size=window_size,
            softcap=0.0,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=False,
            sinks=layer_weight.attn_sinks.weight,
        )
        return o