@@ -702,9 +702,11 @@ def _drop_experts_during_export(self) -> None:
         active_slice = hp.active_slice

         # TODO: @ataghibakhsh: Hack sorting here, move to proper place.
-        importance = hp.importance.argsort(descending=True)
-
-        self.local_experts = nn.ModuleList([self.local_experts[i] for i in importance])
+        importance = hp.importance
+        if (importance.sum() > 0.0).item():
+            importance = hp.importance.argsort(descending=True)
+
+            self.local_experts = nn.ModuleList([self.local_experts[i] for i in importance])

         if isinstance(active_slice, slice):
             # No sorting applied, keep first N experts
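
A minimal standalone sketch of the guard above, using a toy module list and made-up scores (not the modelopt code): when the accumulated importance is all zeros, i.e. no calibration signal was collected, the experts keep their original order and the subsequent active_slice step simply retains the first N.

import torch
from torch import nn

experts = nn.ModuleList(nn.Linear(4, 4) for _ in range(4))
importance = torch.tensor([0.1, 0.9, 0.0, 0.4])  # hypothetical per-expert scores

# Re-order experts by importance only if any score was actually accumulated
if (importance.sum() > 0.0).item():
    order = importance.argsort(descending=True)
    experts = nn.ModuleList([experts[i] for i in order])

# Analogous to the active_slice step: keep the first N experts
experts = experts[:2]
print(len(experts))  # 2; most important experts come first when scores were available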
@@ -721,41 +723,38 @@ def _drop_experts_during_export(self) -> None:

     def _track_expert_l2_importance(self, module, input, output):
         """Track expert importance based on L2 norms of expert outputs."""
-        # Input: (permuted_local_hidden_states, tokens_per_expert, permuted_probs)
-        # Output: (output_local, output_bias_local)

-        if len(input) >= 2 and isinstance(output, tuple):
-            tokens_per_expert = input[1]  # tokens_per_expert tensor
-            output_local = output[0]  # output_local tensor
-
-            # Convert to float32 for precision
-            output_local = output_local.to(torch.float32).detach()
+        tokens_per_expert = input[1]  # tokens_per_expert tensor
+        output_local = output[0]  # output_local tensor
+
+        # Convert to float32 for precision
+        output_local = output_local.to(torch.float32).detach()
+
+        # Split output back to per-expert outputs using torch.split
+        tokens_per_expert_list = tokens_per_expert.tolist()
+
+        output_local_list = torch.split(output_local, tokens_per_expert_list)
+
+        # Compute L2 norm for each expert's output
+        for expert_idx, expert_output in enumerate(output_local_list):
+            # Guard: if expert_output is an empty tensor, add a zero score
+            if expert_output.numel() == 0:
+                l2_norm = 0.0
+            else:
+                # Per-token L2 norms of the expert output (router_prob * expert_output), summed over tokens
+                l2_norm = torch.linalg.vector_norm(expert_output, ord=2, dim=-1).sum().item()

-            # Split output back to per-expert outputs using torch.split
-            tokens_per_expert_list = tokens_per_expert.tolist()
-            if len(tokens_per_expert_list) > 0:
-                output_local_list = torch.split(output_local, tokens_per_expert_list)
-
-                # Compute L2 norm for each expert's output
-                for expert_idx, expert_output in enumerate(output_local_list):
-                    if expert_idx < len(self._expert_l2_scores):
-                        # Guard: if expert_output is empty tensor, add zero score
-                        if expert_output.numel() == 0:
-                            l2_norm = 0.0
-                        else:
-                            # Compute L2 norm of expert output (router_prob * expert_output)
-                            l2_norm = torch.linalg.vector_norm(expert_output, ord=2).item()
-
-                        # Accumulate L2 scores and sample counts
-                        self._expert_l2_scores[expert_idx] += l2_norm
-                        self._expert_sample_counts[expert_idx] += 1
+            # Accumulate L2 scores and per-expert token counts
+            self._expert_l2_scores[expert_idx] += l2_norm
+            self._expert_sample_counts[expert_idx] += tokens_per_expert_list[expert_idx]
+

     def _estimate_expert_importance(self) -> TracedHp.Importance:
         """Estimate expert importance based on accumulated L2 norms."""
         # Average L2 scores across samples (avoid division by zero)
         avg_l2_scores = self._expert_l2_scores / (self._expert_sample_counts + 1e-8)
         # Normalize to get importance scores
-        return avg_l2_scores / (avg_l2_scores.sum() + 1e-8)
+        return avg_l2_scores

     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for all expert MLPs."""