 
 logger = init_logger(__name__)
 
+
 class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(self, layer_num, network_config, mode=[]):
         super().__init__(layer_num, network_config, mode)
-        self.hidden_size = self.network_config_['hidden_size']
+        self.hidden_size = self.network_config_["hidden_size"]
         self.alpha = 1.702
         self.limit = 7.0
-        self.top_k = network_config['num_experts_per_tok']
-        self.sliding_window = network_config['sliding_window']
+        self.top_k = network_config["num_experts_per_tok"]
+        self.sliding_window = network_config["sliding_window"]
         self.head_dim_ = network_config["head_dim"]
 
     def _bind_attention(self):
         self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         self._context_attention_kernel = self._conext_sliding_attention_flashattention
         self._token_attention_kernel = self._token_sliding_attention_flashattention
-
+
     def _bind_norm(self):
         self._att_norm = self._att_norm
         self._ffn_norm = self._ffn_norm
         return
 
-    def _experts(self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight):
+    def _experts(
+        self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight
+    ):
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
         num_experts = routing_weights.shape[1]
 
         hidden_states = hidden_states.repeat(num_experts, 1)
         hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-        gate_up = torch.bmm(hidden_states, layer_weight.gate_up_proj_weight) + layer_weight.gate_up_proj_bias.weight[..., None, :]
+        gate_up = (
+            torch.bmm(hidden_states, layer_weight.gate_up_proj_weight)
+            + layer_weight.gate_up_proj_bias.weight[..., None, :]
+        )
         gate, up = gate_up[..., ::2], gate_up[..., 1::2]
         gate = gate.clamp(min=None, max=self.limit)
         up = up.clamp(min=-self.limit, max=self.limit)
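
The gate/up projection above stores gate and up channels interleaved along the last dimension and clamps them before the activation. A minimal standalone sketch of that split-and-clamp step, with made-up shapes (the limit value matches the constant set in __init__):

import torch

limit = 7.0
gate_up = torch.randn(2, 4, 16)                    # (num_experts, num_tokens, 2 * intermediate), toy shapes
gate, up = gate_up[..., ::2], gate_up[..., 1::2]   # even channels -> gate, odd channels -> up
gate = gate.clamp(min=None, max=limit)             # cap gate activations from above only
up = up.clamp(min=-limit, max=limit)               # cap up activations symmetrically
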
@@ -52,21 +58,17 @@ def _experts(self, hidden_states: torch.Tensor, router_indices, routing_weights,
         next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
         next_states = next_states.sum(dim=0)
         return next_states
-
-    def _att_norm(
-        self, input, infer_state, layer_weight
-    ) -> torch.Tensor:
+
+    def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
         out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
         return out
-
-    def _ffn_norm(
-        self, input, infer_state, layer_weight
-    ) -> torch.Tensor:
+
+    def _ffn_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
         out = self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
         return out
-
+
     def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
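
_experts uses a dense dispatch: every token is pushed through every expert with a single bmm, and the per-expert outputs are then scaled by the routing scores and summed. A self-contained sketch of that pattern with toy shapes (the names and sizes here are illustrative only):

import torch

num_experts, num_tokens, hidden = 4, 3, 8
x = torch.randn(num_tokens, hidden)
w = torch.randn(num_experts, hidden, hidden)            # stacked per-expert weight matrices
scores = torch.softmax(torch.randn(num_tokens, num_experts), dim=1)

h = x.repeat(num_experts, 1).view(num_experts, num_tokens, hidden)
out = torch.bmm(h, w)                                   # (num_experts, num_tokens, hidden)
out = out * scores.transpose(0, 1)[..., None]           # weight each expert's output per token
out = out.sum(dim=0)                                    # combine back to (num_tokens, hidden)
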
@@ -81,18 +83,24 @@ def _router(self, hidden_states, layer_weight: GptOssTransformerLayerWeight):
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
         return router_scores, router_indices
-
-    def _ffn(self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor:
+
+    def _ffn(
+        self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight
+    ) -> torch.Tensor:
         router_scores, router_indices = self._router(input, layer_weight)  # (num_experts, seq_len)
-        routed_out = self._experts(input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight)
+        routed_out = self._experts(
+            input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight
+        )
         return routed_out
-
-    def _conext_sliding_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
-        if self.network_config_['layer_types'][self.layer_num_] == "sliding_attention":
-            window_size = (self.sliding_window - 1, self.sliding_window - 1)
+
+    def _conext_sliding_attention_flashattention(
+        self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
+    ):
+        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
+            window_size = (self.sliding_window - 1, self.sliding_window - 1)
         else:
             window_size = (-1, -1)
-
+
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )
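
_router keeps only the top-k logits per token, softmaxes over just those values, and scatters the probabilities back into a dense (num_tokens, num_experts) score matrix so that non-selected experts contribute zero. A sketch of that pattern (the topk call is assumed, since it falls outside the lines shown; top_k=2 and the shapes are illustrative):

import torch

router_logits = torch.randn(3, 8)                        # (num_tokens, num_experts), toy shapes
top_value, top_idx = router_logits.topk(2, dim=1)        # assumed top-k selection (k = num_experts_per_tok)
top_value = torch.nn.functional.softmax(top_value, dim=1, dtype=top_value.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, top_idx, top_value)
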
@@ -114,7 +122,7 @@ def _conext_sliding_attention_flashattention(self, q, kv, infer_state: FlashAtte
             max_seqlen_q=infer_state.q_max_seq_len,
             softmax_scale=sm_scale,
             causal=True,
-            window_size=(-1, -1),
+            window_size=window_size,
             softcap=0.0,
             k_descale=k_descale,
             v_descale=v_descale,
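
The one behavioral change in this hunk: the context-attention call previously hard-coded window_size=(-1, -1) (full causal attention), so the sliding-window branch computed above had no effect; it now passes the per-layer value through. A hypothetical helper mirroring that selection, assuming FlashAttention's convention that -1 disables the window:

def select_window_size(layer_types, layer_num, sliding_window):
    # Layers marked "sliding_attention" attend to at most the previous
    # (sliding_window - 1) tokens plus the current one; under causal masking
    # the right half of the tuple is unused. All other layers get
    # unrestricted causal attention.
    if layer_types[layer_num] == "sliding_attention":
        return (sliding_window - 1, sliding_window - 1)
    return (-1, -1)
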
@@ -124,11 +132,11 @@ def _conext_sliding_attention_flashattention(self, q, kv, infer_state: FlashAtte
         return o
 
     def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
-        if self.network_config_['layer_types'][self.layer_num_] == "sliding_attention":
-            window_size = (self.sliding_window - 1, self.sliding_window - 1)
+        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
+            window_size = (self.sliding_window - 1, self.sliding_window - 1)
         else:
             window_size = (-1, -1)
-
+
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )
@@ -157,4 +165,4 @@ def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttention
             return_softmax_lse=False,
             sinks=layer_weight.attn_sinks.weight,
         )
-        return o
+        return o