
Commit c83780f

fix pruning when not using score for experts
1 parent 3385820

File tree

1 file changed: +29 −30 lines


modelopt/torch/nas/plugins/megatron.py

Lines changed: 29 additions & 30 deletions
@@ -702,9 +702,11 @@ def _drop_experts_during_export(self) -> None:
         active_slice = hp.active_slice
 
         # TODO: @ataghibakhsh: Hack sorting here, move to proper place.
-        importance = hp.importance.argsort(descending=True)
-
-        self.local_experts = nn.ModuleList([self.local_experts[i] for i in importance])
+        importance = hp.importance
+        if (importance.sum() > 0.0).item():
+            importance = hp.importance.argsort(descending=True)
+
+            self.local_experts = nn.ModuleList([self.local_experts[i] for i in importance])
 
         if isinstance(active_slice, slice):
             # No sorting applied, keep first N experts
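
The new guard addresses the case named in the commit message: when no scoring forward passes have run, `hp.importance` is an all-zero tensor, and argsort over those ties would shuffle the experts in a meaningless order before export. A minimal sketch of the guarded behavior, assuming only that `importance` holds one score per expert (toy modules, not ModelOpt's classes):

    import torch
    import torch.nn as nn

    experts = nn.ModuleList(nn.Linear(4, 4) for _ in range(4))
    importance = torch.zeros(len(experts))  # no scoring pass ran -> all zeros

    if (importance.sum() > 0.0).item():
        # Scores exist: put the most important experts first so that a
        # later slice keeps the top-N.
        order = importance.argsort(descending=True)
        experts = nn.ModuleList(experts[int(i)] for i in order)
    # With all-zero scores, the experts keep their original order and the
    # slice branch simply keeps the first N.
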
@@ -721,41 +723,38 @@ def _drop_experts_during_export(self) -> None:
 
     def _track_expert_l2_importance(self, module, input, output):
         """Track expert importance based on L2 norms of expert outputs."""
-        # Input: (permuted_local_hidden_states, tokens_per_expert, permuted_probs)
-        # Output: (output_local, output_bias_local)
 
-        if len(input) >= 2 and isinstance(output, tuple):
-            tokens_per_expert = input[1]  # tokens_per_expert tensor
-            output_local = output[0]  # output_local tensor
-
-            # Convert to float32 for precision
-            output_local = output_local.to(torch.float32).detach()
+        tokens_per_expert = input[1]  # tokens_per_expert tensor
+        output_local = output[0]  # output_local tensor
+
+        # Convert to float32 for precision
+        output_local = output_local.to(torch.float32).detach()
+
+        # Split output back to per-expert outputs using torch.split
+        tokens_per_expert_list = tokens_per_expert.tolist()
+
+        output_local_list = torch.split(output_local, tokens_per_expert_list)
+
+        # Compute L2 norm for each expert's output
+        for expert_idx, expert_output in enumerate(output_local_list):
+            # Guard: if expert_output is empty tensor, add zero score
+            if expert_output.numel() == 0:
+                l2_norm = 0.0
+            else:
+                # Compute L2 norm of expert output (router_prob * expert_output)
+                l2_norm = torch.linalg.vector_norm(expert_output, ord=2, dim=-1).sum().item()
 
-            # Split output back to per-expert outputs using torch.split
-            tokens_per_expert_list = tokens_per_expert.tolist()
-            if len(tokens_per_expert_list) > 0:
-                output_local_list = torch.split(output_local, tokens_per_expert_list)
-
-                # Compute L2 norm for each expert's output
-                for expert_idx, expert_output in enumerate(output_local_list):
-                    if expert_idx < len(self._expert_l2_scores):
-                        # Guard: if expert_output is empty tensor, add zero score
-                        if expert_output.numel() == 0:
-                            l2_norm = 0.0
-                        else:
-                            # Compute L2 norm of expert output (router_prob * expert_output)
-                            l2_norm = torch.linalg.vector_norm(expert_output, ord=2).item()
-
-                        # Accumulate L2 scores and sample counts
-                        self._expert_l2_scores[expert_idx] += l2_norm
-                        self._expert_sample_counts[expert_idx] += 1
+            # Accumulate L2 scores and sample counts
+            self._expert_l2_scores[expert_idx] += l2_norm
+            self._expert_sample_counts[expert_idx] += tokens_per_expert_list[expert_idx]
+
 
     def _estimate_expert_importance(self) -> TracedHp.Importance:
         """Estimate expert importance based on accumulated L2 norms."""
         # Average L2 scores across samples (avoid division by zero)
         avg_l2_scores = self._expert_l2_scores / (self._expert_sample_counts + 1e-8)
         # Normalize to get importance scores
-        return avg_l2_scores / (avg_l2_scores.sum() + 1e-8)
+        return avg_l2_scores
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for all expert MLPs."""
