@@ -73,7 +73,9 @@ def reset_parameters(self) -> None:
 
     def _init_weights(self, module: nn.Module) -> None:
         """Meant to be used with `gpt.apply(gpt._init_weights)`."""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, GroupedTopkRouter):
+            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
+        elif isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
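The extra branch is needed because GroupedTopkRouter (added further down) stores its routing matrix as a bare nn.Parameter created with torch.empty, so the existing nn.Linear branch would never initialize it. A minimal standalone sketch of the same dispatch pattern, using a hypothetical ToyRouter in place of the real class:

import torch
import torch.nn as nn

class ToyRouter(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Bare parameter, intentionally left uninitialized (mirrors torch.empty in the router)
        self.weight = nn.Parameter(torch.empty(4, 8))

def init_fn(module: nn.Module) -> None:
    if isinstance(module, ToyRouter):
        torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
    elif isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(8, 8))
model.router = ToyRouter()   # registered as a submodule
model.apply(init_fn)         # visits the nn.Linear and the ToyRouter alike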
@@ -286,6 +288,8 @@ def __init__(
             else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
         )
         self.mlp = config.mlp_class(config)
+        if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:
+            self.mlp = LLaMAMLP(config)
         self.post_mlp_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
         )
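A quick illustration of the check above, DeepSeek-style "first k dense" layers: the first first_k_dense_replace blocks keep a dense LLaMAMLP and only later blocks use the MoE mlp_class. The numbers below are made up, not taken from any real config:

first_k_dense_replace = 3   # hypothetical config value
n_layer = 8                 # hypothetical config value
for block_idx in range(n_layer):
    dense = first_k_dense_replace is not None and block_idx < first_k_dense_replace
    print(block_idx, "LLaMAMLP (dense)" if dense else "config.mlp_class (MoE)")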
@@ -734,10 +738,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
+        self.gate = (
+            nn.Linear(config.n_embd, config.n_expert, bias=False)
+            if not config.n_expert_groups
+            else GroupedTopkRouter(config)
+        )
         self.experts = nn.ModuleList(
             LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
         )
+        if config.n_shared_expert:
+            self.shared_experts = LLaMAMLP(
+                config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert
+            )
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
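A rough sketch of how the shared expert is meant to combine with the routed output in the forward pass that follows: it always runs on the full, unrouted input. Stand-in modules and made-up sizes; the real shared expert is a LLaMAMLP with intermediate_size = moe_intermediate_size * n_shared_expert.

import torch
import torch.nn as nn

B, T, C = 2, 5, 8                  # made-up batch, sequence, and embedding sizes
x = torch.randn(B, T, C)
shared_experts = nn.Linear(C, C)   # stands in for the wider shared LLaMAMLP
routed_y = torch.zeros(B, T, C)    # stands in for the weighted sum over routed experts
y = routed_y + shared_experts(x)   # the shared expert sees every token, not just routed ones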
@@ -746,17 +758,71 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         See also figure 1 in https://arxiv.org/abs/2211.15841
         """
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+        residual_x = x.clone()
         x = x.view(-1, C)  # (B*T, C)
-        router = self.gate(x)  # (B*T, n_expert)
-        probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
-        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
+        if not self.config.n_expert_groups:
+            router = self.gate(x)  # (B*T, n_expert)
+            probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
+            probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
+        else:
+            probs, indices = self.gate(x)
+            if self.config.routed_scaling_factor != 1.0:
+                probs = probs * self.config.routed_scaling_factor
         masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
         masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
         y = torch.zeros_like(x)  # (B*T, C)
         for mask, expert in zip(masks, self.experts):
             token_idx, expert_idx = torch.where(mask)
             y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
-        return y.view(B, T, C)
+
+        y = y.view(B, T, C)
+        if self.config.n_shared_expert:
+            y = y + self.shared_experts(residual_x)
+        return y
+
+
+class GroupedTopkRouter(nn.Module):
+    """
+    Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py.
+    DeepseekV3TopkRouter class.
+    """
+
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.config = config
+        self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd))
+        self.register_buffer("e_score_correction_bias", torch.zeros(config.n_expert))
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:
+        scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
+            .topk(self.config.n_topk_scores_per_group, dim=-1)[0]  # Top k scores for each group
+            .sum(dim=-1)
+        )
+
+        group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
+            .reshape(-1, self.config.n_expert)
+        )
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+        topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1]
+        return topk_indices
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32))
+        scores = router_logits.sigmoid()
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+        if self.config.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        return topk_weights, topk_indices
 
 
 def build_rope_cache(
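For reference, a self-contained sketch of the grouped (node-limited) selection GroupedTopkRouter implements: bias-corrected sigmoid scores rank each group by the sum of its best expert scores, experts outside the winning groups are masked out, the per-token top-k experts are drawn from what remains, and the returned weights come from the unbiased scores (optionally normalized, then scaled by routed_scaling_factor in LLaMAMoE.forward). All sizes below are made up.

import torch

n_expert, n_groups = 16, 4               # made-up sizes: 16 experts split into 4 groups
n_topk_groups, n_scores_per_group = 2, 2
n_expert_per_token = 4
scores = torch.rand(3, n_expert)         # sigmoid router scores for 3 tokens
bias = torch.zeros(n_expert)             # stands in for e_score_correction_bias

scores_for_choice = scores + bias        # the bias only influences selection, not the returned weights
group_scores = (
    scores_for_choice.view(-1, n_groups, n_expert // n_groups)
    .topk(n_scores_per_group, dim=-1)[0]
    .sum(dim=-1)
)                                        # (3, n_groups): each group ranked by its best experts
group_idx = group_scores.topk(n_topk_groups, dim=-1)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.repeat_interleave(n_expert // n_groups, dim=-1)   # expand mask to expert granularity
masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = masked_scores.topk(n_expert_per_token, dim=-1)[1]
topk_weights = scores.gather(1, topk_indices)                             # weights taken from the raw scores
print(topk_indices)
print(topk_weights)

The repeat_interleave step is equivalent to the unsqueeze/expand/reshape mask construction in the diff because experts within a group occupy contiguous indices; it is written this way here only for brevity.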