1616"""Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""
1717
1818import types
19+ from abc import ABC
1920from collections .abc import Callable , Sequence
2021from typing import Any
2122
5253from megatron .core .transformer .moe .shared_experts import SharedExpertMLP
5354from megatron .core .transformer .transformer_layer import TransformerLayer
5455
56+ from modelopt .torch .nas .modules import DynamicModuleList
5557from modelopt .torch .opt .dynamic import DynamicModule
5658from modelopt .torch .opt .hparam import HPType
5759from modelopt .torch .opt .searcher import ConstraintsDict
58- from modelopt .torch .opt .utils import named_hparams
5960from modelopt .torch .trace import Symbol
6061from modelopt .torch .utils import distributed as dist
6162from modelopt .torch .utils import (
@@ -201,6 +202,8 @@ def _setup(self):
         )
         if isinstance(self, SharedExpertMLP):
             self.hparam_name = "moe_shared_expert_intermediate_size"
+        elif self.config.num_moe_experts is not None:
+            self.hparam_name = "moe_ffn_hidden_size"
         else:
             self.hparam_name = "ffn_hidden_size"
         self.linear_fc1 = DMRegistry.convert(self.linear_fc1)
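For orientation, a hedged sketch of how the name selected above can be looked up afterwards, assuming (as the surrounding code suggests) that the converted module registers its FFN-width hparam under `hparam_name`; `mlp` and `dyn_mlp` are hypothetical:

    # Hypothetical sketch, not part of this change.
    dyn_mlp = DMRegistry.convert(mlp)             # mlp: a megatron-core MLP used as an MoE expert
    hp = dyn_mlp.get_hparam(dyn_mlp.hparam_name)  # e.g. "moe_ffn_hidden_size" for per-expert MLPs
    print(hp.max, hp.active)                      # the FFN width this plugin can search/prune over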
@@ -650,8 +653,9 @@ def export(self) -> torch.nn.Module:


 # MoE DynamicModules ###############################################################################
+# Add ABC to avoid TypeError: object layout differs (because the parent TopKRouter inherits from ABC)
 @DMRegistry.register({TopKRouter: "megatron.core.transformer.moe.router.TopKRouter"})
-class _DynamicTopKRouter(DynamicModule):
+class _DynamicTopKRouter(DynamicModule, ABC):
     """A TopKRouter with dynamic hyperparams."""

     def _setup(self):
@@ -660,11 +664,11 @@ def _setup(self):
         # Register num_moe_experts hparam name to match TransformerConfig's name.
         # Will be overridden by _DynamicSequentialMLP's hp.
         self._register_hparam("num_moe_experts", TracedHp(list(range(1, self.weight.shape[0] + 1))))
-        self._register_dynamic_attribute("num_experts", lambda mod, val: mod.num_moe_experts)
         # Register hidden_size reference (will be overridden by _DynamicMoELayer's hidden_size)
         self._register_hparam("hidden_size", TracedHp(list(range(1, self.weight.shape[1] + 1))))

         # Register dynamic attributes
+        self._register_dynamic_attribute("num_experts", lambda mod, val: mod.num_moe_experts)
         self._register_dynamic_attribute("weight", self._get_router_weight)
         if self.enable_expert_bias:
             self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_moe_experts)
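A minimal illustration of what the `num_experts` dynamic attribute above provides, assuming a converted router `dyn_router` (hypothetical name) with at least 4 experts:

    # Hypothetical sketch, not part of this change.
    dyn_router.get_hparam("num_moe_experts").active = 4
    assert dyn_router.num_experts == 4  # the plain attribute mirrors the active hparam value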
@@ -702,7 +706,9 @@ def _setup(self):
             lambda mod, val: mod.num_moe_experts,  # EP = 1
         )

-        # Convert each individual expert MLP to dynamic
+        # Convert local_experts list and each individual expert MLP to dynamic modules
+        self.local_experts = DynamicModuleList.convert(self.local_experts)
+        self.local_experts.depth = num_moe_experts  # Reuse same hparam for depth
         for i in range(len(self.local_experts)):
             self.local_experts[i] = DMRegistry.convert(self.local_experts[i])

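A brief sketch of the intent behind reusing the hparam for the list depth, assuming a converted `_DynamicSequentialMLP` named `seq_mlp` (hypothetical) with at least 2 experts:

    # Hypothetical sketch, not part of this change.
    hp = seq_mlp.get_hparam("num_moe_experts")
    hp.active = 2  # shrinking the expert count...
    # ...also shortens the dynamic expert list, since its depth reuses the same hparam object.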
@@ -725,6 +731,11 @@ def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:

     def _expert_l2_imp_forward_hook(self, module, input, output):
         """Track expert importance based on L2 norms of expert outputs."""
+        # Don't aggregate activations from non-max subnets (e.g. from profiling)
+        num_moe_experts = self.get_hparam("num_moe_experts")
+        if num_moe_experts.active != num_moe_experts.max:
+            return
+
         # Split output back to per-expert outputs using torch.split
         tokens_per_expert_list = input[1].tolist()
         # use full precision to avoid overflow
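For reference, a standalone sketch of the torch.split pattern the comment above refers to, with made-up shapes (not code from this change):

    import torch

    output = torch.randn(10, 8)    # concatenated expert outputs, 10 tokens total
    tokens_per_expert = [4, 3, 3]  # tokens routed to each of 3 experts
    per_expert = torch.split(output, tokens_per_expert, dim=0)
    l2_norms = [o.float().norm(dim=-1).sum() for o in per_expert]  # full-precision per-expert L2 sums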
@@ -757,26 +768,11 @@ def _estimate_expert_importance(self) -> TracedHp.Importance:
             self._activations["expert_sample_counts"] + 1e-8
         )

-    def _export_drop_experts(self) -> None:
-        """Drop experts during export based on active hyperparameter value."""
-        # Get sorted + trimmed order of experts to keep
-        active_slice = self.get_hparam("num_moe_experts").active_slice
-
-        # Trim experts based on active hparam value
-        if isinstance(active_slice, slice):
-            kept_experts = self.local_experts[: active_slice.stop]
-        else:
-            kept_experts = [self.local_experts[i] for i in active_slice]
-
-        # Replace the ModuleList with pruned experts
-        self.local_experts = nn.ModuleList(kept_experts)
-
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a standard SequentialMLP."""
         self.hook_handle.remove()

-        # Drop experts based on active hparam value and export remaining experts
-        self._export_drop_experts()
+        self.local_experts.export()
         for expert in self.local_experts:
             expert.export()

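The deleted `_export_drop_experts` logic is presumably subsumed by `DynamicModuleList.export()`; a hedged sketch of the equivalent behaviour, reconstructed from the removed code above (`experts` is a hypothetical `_DynamicSequentialMLP`, and this is not the actual `DynamicModuleList` implementation):

    # Assumed-equivalent behaviour for this use case, not the library's code.
    active_slice = experts.get_hparam("num_moe_experts").active_slice
    if isinstance(active_slice, slice):
        kept = list(experts.local_experts)[: active_slice.stop]
    else:
        kept = [experts.local_experts[i] for i in active_slice]
    experts.local_experts = nn.ModuleList(kept)  # relies on the file's existing torch.nn import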
@@ -789,9 +785,6 @@ class _DynamicMoELayer(DynamicModule):
789785 """A MoELayer with dynamic hyperparams."""
790786
791787 def _setup (self ):
792- # TODO: Add DynamicTokenDispatcher for moe_shared_expert_overlap support
793- assert not self .shared_expert_overlap , "moe_shared_expert_overlap is not supported yet!"
794-
795788 # Convert to dynamic modules
796789 # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so
797790 # importance estimator is not lost.
@@ -810,7 +803,11 @@ def _setup(self):

     def _get_local_expert_indices(self, mod: "_DynamicMoELayer", val: list[int]) -> list[int]:
         """Get local expert indices for the current active hparam value."""
-        return list(range(mod.num_local_experts))
+        active_slice = self.experts.get_hparam("num_moe_experts").active_slice
+        if isinstance(active_slice, slice):
+            return list(range(active_slice.stop))
+        else:
+            return active_slice.tolist()

     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for all MoE components from global hidden_size hparam."""
@@ -955,17 +952,6 @@ def export(self):
         super().export()
         return self

-    def freeze(self):
-        """Freeze the dynamic module."""
-        super().freeze()
-        if isinstance(self.self_attention, SelfAttention):
-            self.input_layernorm.freeze()
-            self.self_attention.freeze()
-
-        if isinstance(self.mlp, (MLP, MoELayer)):
-            self.pre_mlp_layernorm.freeze()
-            self.mlp.freeze()
-

 # Mamba DynamicModules #############################################################################
 class MambaNumHeadsHp(TracedHp):
@@ -1356,11 +1342,6 @@ def export(self):
         super().export()
         return self

-    def freeze(self):
-        """Freeze the hyperparameters."""
-        self.mixer.freeze()
-        super().freeze()
-

 if HAS_MAMBA:
     DMRegistry.register({ExtendedRMSNorm: "megatron.core.ssm.mamba_mixer.ExtendedRMSNorm"})(
@@ -1559,12 +1540,6 @@ def _export_drop_layers(self) -> None:

     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a torch.nn.Module."""
-        # TODO: Improve this!
-        # Slice order needs to be reset before exporting since weights are already
-        # force assigned and we dont want to sort them again (losing the correct order)
-        for n, hp in named_hparams(self, configurable=True):
-            hp.enforce_order(None)
-
         for handle in self.hook_handles:
             handle.remove()
         self._export_drop_layers()
@@ -1578,12 +1553,6 @@ def export(self) -> torch.nn.Module:
         super().export()
         return self

-    def freeze(self) -> None:
-        """Freeze the dynamic module."""
-        super().freeze()
-        for layer in self.decoder.layers:
-            layer.freeze()
-
     def get_activations_and_layer_scores(
         self,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]: