1616"""Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""
1717
1818import types
19+ from abc import ABC
1920from collections .abc import Callable , Sequence
2021from typing import Any
2122
5253from megatron .core .transformer .moe .shared_experts import SharedExpertMLP
5354from megatron .core .transformer .transformer_layer import TransformerLayer
5455
56+ from modelopt .torch .nas .modules import DynamicModuleList
5557from modelopt .torch .opt .dynamic import DynamicModule
5658from modelopt .torch .opt .hparam import HPType
5759from modelopt .torch .opt .searcher import ConstraintsDict
58- from modelopt .torch .opt .utils import named_hparams
5960from modelopt .torch .trace import Symbol
6061from modelopt .torch .utils import distributed as dist
6162from modelopt .torch .utils import (
@@ -201,6 +202,8 @@ def _setup(self):
        )
        if isinstance(self, SharedExpertMLP):
            self.hparam_name = "moe_shared_expert_intermediate_size"
+        elif self.config.num_moe_experts is not None:
+            self.hparam_name = "moe_ffn_hidden_size"
        else:
            self.hparam_name = "ffn_hidden_size"
        self.linear_fc1 = DMRegistry.convert(self.linear_fc1)
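For orientation, the three hparam names selected above mirror the TransformerConfig fields that size each MLP variant. A small illustrative summary (not part of the commit; the mapping is inferred from the branches above):

# Illustrative only: which TransformerConfig field each hparam_name mirrors.
HPARAM_NAME_BY_MLP_KIND = {
    "shared_expert_mlp": "moe_shared_expert_intermediate_size",
    "routed_expert_mlp": "moe_ffn_hidden_size",  # config.num_moe_experts is not None
    "dense_mlp": "ffn_hidden_size",
}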
@@ -274,8 +277,7 @@ def export(self) -> torch.nn.Module:
        self.hook_handle.remove()
        self.linear_fc1.export()
        self.linear_fc2.export()
-        super().export()
-        return self
+        return super().export()


# SelfAttention DynamicModules #####################################################################
@@ -645,42 +647,37 @@ def export(self) -> torch.nn.Module:
        self.core_attention.export()
        self.linear_qkv.export()
        self.linear_proj.export()
-        super().export()
-        return self
+        return super().export()


# MoE DynamicModules ###############################################################################
+# Add ABC to avoid "TypeError: object layout differs" (because the parent TopKRouter inherits from ABC)
@DMRegistry.register({TopKRouter: "megatron.core.transformer.moe.router.TopKRouter"})
-class _DynamicTopKRouter(DynamicModule):
+class _DynamicTopKRouter(ABC, DynamicModule):
    """A TopKRouter with dynamic hyperparams."""

    def _setup(self):
-        # Register hparams for router weight dimensions
+        # Register hparams for router weight dimensions (will be overridden by _DynamicSequentialMLP's hp)
        # Router weight shape: [num_moe_experts, hidden_size]
-        # Register num_moe_experts hparam name to match TransformerConfig's name.
-        # Will be overridden by _DynamicSequentialMLP's hp.
-        self._register_hparam("num_moe_experts", TracedHp(list(range(1, self.weight.shape[0] + 1))))
-        self._register_dynamic_attribute("num_experts", lambda mod, val: mod.num_moe_experts)
+        self._register_hparam("num_experts", TracedHp(list(range(1, self.weight.shape[0] + 1))))
        # Register hidden_size reference (will be overridden by _DynamicMoELayer's hidden_size)
        self._register_hparam("hidden_size", TracedHp(list(range(1, self.weight.shape[1] + 1))))

        # Register dynamic attributes
        self._register_dynamic_attribute("weight", self._get_router_weight)
        if self.enable_expert_bias:
-            self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_moe_experts)
+            self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_experts)
            self._register_dynamic_attribute(
-                "local_tokens_per_expert", self._get_slice_by_num_moe_experts
+                "local_tokens_per_expert", self._get_slice_by_num_experts
            )

    @staticmethod
    def _get_router_weight(mod: "_DynamicTopKRouter", weight: torch.Tensor) -> torch.Tensor:
-        return get_sliced_tensor(mod, weight, "num_moe_experts", "hidden_size")
+        return get_sliced_tensor(mod, weight, "num_experts", "hidden_size")

    @staticmethod
-    def _get_slice_by_num_moe_experts(
-        mod: "_DynamicTopKRouter", bias: torch.Tensor
-    ) -> torch.Tensor:
-        return get_sliced_tensor(mod, bias, "num_moe_experts")
+    def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", bias: torch.Tensor) -> torch.Tensor:
+        return get_sliced_tensor(mod, bias, "num_experts")

    def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
        """Set hidden_size hparam for router weights from global hidden_size hparam."""
@@ -692,17 +689,13 @@ class _DynamicSequentialMLP(DynamicModule):
692689 """A SequentialMLP with dynamic hyperparams."""
693690
694691 def _setup (self ):
695- # Register hparam for number of active experts
696- # Use num_moe_experts hparam name to match TransformerConfig's name
697- # Will be shared with _DynamicTopKRouter's hp.
692+ # Register hparam for number of active experts (will be shared with _DynamicTopKRouter's hp)
698693 num_moe_experts = TracedHp (list (range (1 , self .num_local_experts + 1 )))
699- self ._register_hparam ("num_moe_experts" , num_moe_experts )
700- self ._register_dynamic_attribute (
701- "num_local_experts" ,
702- lambda mod , val : mod .num_moe_experts , # EP = 1
703- )
694+ self ._register_hparam ("num_local_experts" , num_moe_experts )
704695
705- # Convert each individual expert MLP to dynamic
696+ # Convert local_experts list and each individual expert MLP to dynamic modules
697+ self .local_experts = DynamicModuleList .convert (self .local_experts )
698+ self .local_experts .depth = num_moe_experts # Reuse same hparam for depth
706699 for i in range (len (self .local_experts )):
707700 self .local_experts [i ] = DMRegistry .convert (self .local_experts [i ])
708701
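The DynamicModuleList conversion above lets one shared hyperparameter drive both the reported expert count and how many expert modules actually remain. A rough, self-contained sketch of that idea with toy stand-ins (this is not the real DynamicModuleList/TracedHp API):

import torch.nn as nn

class ToyExpertList(nn.ModuleList):
    """Toy stand-in for DynamicModuleList: one shared depth knob trims trailing experts."""

    def __init__(self, experts):
        super().__init__(experts)
        self.depth = len(experts)  # a plain int here; a TracedHp hparam in the real code

    def active_experts(self):
        # Only the first `depth` experts are considered part of the active subnet.
        return list(self[: self.depth])

experts = ToyExpertList([nn.Linear(4, 4) for _ in range(8)])
experts.depth = 6  # pruning the expert count to 6 also drops the last two experts
assert len(experts.active_experts()) == 6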
@@ -725,6 +718,11 @@ def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:

    def _expert_l2_imp_forward_hook(self, module, input, output):
        """Track expert importance based on L2 norms of expert outputs."""
+        # Don't aggregate activations from non-max subnets (e.g. from profiling)
+        num_moe_experts = self.get_hparam("num_local_experts")
+        if num_moe_experts.active != num_moe_experts.max:
+            return
+
        # Split output back to per-expert outputs using torch.split
        tokens_per_expert_list = input[1].tolist()
        # use full precision to avoid overflow
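The new guard only accumulates importance statistics when the layer runs at its maximum width, so profiling smaller sub-nets does not bias the expert L2-norm averages. A minimal illustration of the check with a toy stand-in for the hparam object (the real hparam used above exposes .active and .max but is otherwise more involved):

class ToyHp:
    """Toy stand-in for an hparam with .active and .max."""

    def __init__(self, choices, active):
        self.choices = choices
        self.active = active

    @property
    def max(self):
        return max(self.choices)

hp = ToyHp(choices=[2, 4, 8], active=4)
should_skip = hp.active != hp.max
print(should_skip)  # True: a sub-net is active, so the hook would return early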
@@ -757,60 +755,50 @@ def _estimate_expert_importance(self) -> TracedHp.Importance:
            self._activations["expert_sample_counts"] + 1e-8
        )

-    def _export_drop_experts(self) -> None:
-        """Drop experts during export based on active hyperparameter value."""
-        # Get sorted + trimmed order of experts to keep
-        active_slice = self.get_hparam("num_moe_experts").active_slice
-
-        # Trim experts based on active hparam value
-        if isinstance(active_slice, slice):
-            kept_experts = self.local_experts[: active_slice.stop]
-        else:
-            kept_experts = [self.local_experts[i] for i in active_slice]
-
-        # Replace the ModuleList with pruned experts
-        self.local_experts = nn.ModuleList(kept_experts)
-
    def export(self) -> torch.nn.Module:
        """Export the dynamic module to a standard SequentialMLP."""
        self.hook_handle.remove()
-
-        # Drop experts based on active hparam value and export remaining experts
-        self._export_drop_experts()
        for expert in self.local_experts:
            expert.export()
-
-        super().export()
-        return self
+        self.local_experts.export()
+        return super().export()


@DMRegistry.register({MoELayer: "megatron.core.transformer.moe.moe_layer.MoELayer"})
class _DynamicMoELayer(DynamicModule):
    """A MoELayer with dynamic hyperparams."""

    def _setup(self):
-        # TODO: Add DynamicTokenDispatcher for moe_shared_expert_overlap support
-        assert not self.shared_expert_overlap, "moe_shared_expert_overlap is not supported yet!"
-
        # Convert to dynamic modules
        # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so
-        # importance estimator is not lost.
+        # importance estimator and depth hparam are retained.
        self.router = DMRegistry.convert(self.router)
        self.experts = DMRegistry.convert(self.experts)
-        num_moe_experts_hp = self.experts.get_hparam("num_moe_experts")
+        num_moe_experts_hp = self.experts.get_hparam("num_local_experts")
+
+        # NOTE: Use num_moe_experts hparam name in top-level module to match TransformerConfig's name
+        self._register_hparam("num_moe_experts", num_moe_experts_hp)
        self._register_dynamic_attribute(
            "num_local_experts",
            lambda mod, val: num_moe_experts_hp.active,  # EP = 1
        )
-        self.router.num_moe_experts = num_moe_experts_hp
+        self.router.num_experts = num_moe_experts_hp
        if self.use_shared_expert:
            self.shared_experts = DMRegistry.convert(self.shared_experts)

        self._register_dynamic_attribute("local_expert_indices", self._get_local_expert_indices)
+        # assert self.config.moe_token_dispatcher_type == "alltoall", (
+        #     "Only moe_token_dispatcher_type=='alltoall' is supported!"
+        # )
+        # self.token_dispatcher = DMRegistry.convert(self.token_dispatcher)

    def _get_local_expert_indices(self, mod: "_DynamicMoELayer", val: list[int]) -> list[int]:
        """Get local expert indices for the current active hparam value."""
-        return list(range(mod.num_local_experts))
+        active_slice = self.get_hparam("num_moe_experts").active_slice
+        if isinstance(active_slice, slice):
+            return list(range(active_slice.stop))
+        else:
+            return active_slice.tolist()

    def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
        """Set hidden size for all MoE components from global hidden_size hparam."""
@@ -824,7 +812,7 @@ def modify(
    ):
        """Modify MoE hparam choices based on search space config."""
        # Modify num_moe_experts hparam choices (applies to both router and experts)
-        expert_hp = self.experts.get_hparam("num_moe_experts")
+        expert_hp = self.get_hparam("num_moe_experts")
        choices = {int(make_divisible(c, num_moe_experts_divisor)) for c in expert_hp.choices}  # type: ignore[arg-type]
        expert_hp.choices = list(set(expert_hp.choices) & choices | {expert_hp.original})

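A worked example of the choice filtering above, with an illustrative reimplementation of make_divisible (the real helper imported by this file may round differently):

def make_divisible(value, divisor):
    # Illustrative only: round to a multiple of divisor, never below divisor.
    return max(divisor, int(round(value / divisor)) * divisor)

original_choices = list(range(1, 9))  # num_moe_experts choices 1..8, original = 8
divisor = 2
divisible = {int(make_divisible(c, divisor)) for c in original_choices}
new_choices = sorted(set(original_choices) & divisible | {8})  # keep the original value
print(new_choices)  # [2, 4, 6, 8]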
@@ -840,8 +828,7 @@ def export(self) -> torch.nn.Module:
        self.experts.export()
        if self.use_shared_expert:
            self.shared_experts.export()
-        super().export()
-        return self
+        return super().export()


# TransformerLayer DynamicModule ###################################################################
@@ -952,19 +939,7 @@ def export(self):
        if isinstance(self.mlp, (MLP, MoELayer)):
            self.pre_mlp_layernorm.export()
            self.mlp.export()
-        super().export()
-        return self
-
-    def freeze(self):
-        """Freeze the dynamic module."""
-        super().freeze()
-        if isinstance(self.self_attention, SelfAttention):
-            self.input_layernorm.freeze()
-            self.self_attention.freeze()
-
-        if isinstance(self.mlp, (MLP, MoELayer)):
-            self.pre_mlp_layernorm.freeze()
-            self.mlp.freeze()
+        return super().export()


# Mamba DynamicModules #############################################################################
@@ -1307,8 +1282,7 @@ def export(self) -> torch.nn.Module:
        self.conv1d.export()
        if self.rmsnorm:
            self.norm.export()
-        super().export()
-        return self
+        return super().export()


class _DynamicMambaLayer(DynamicModule, MambaTransformerLayerMixin):
@@ -1353,13 +1327,7 @@ def export(self):
        self._export_mixin()
        self.mixer.export()
        self.norm.export()
-        super().export()
-        return self
-
-    def freeze(self):
-        """Freeze the hyperparameters."""
-        self.mixer.freeze()
-        super().freeze()
+        return super().export()


if HAS_MAMBA:
@@ -1559,12 +1527,6 @@ def _export_drop_layers(self) -> None:

    def export(self) -> torch.nn.Module:
        """Export the dynamic module to a torch.nn.Module."""
-        # TODO: Improve this!
-        # Slice order needs to be reset before exporting since weights are already
-        # force assigned and we dont want to sort them again (losing the correct order)
-        for n, hp in named_hparams(self, configurable=True):
-            hp.enforce_order(None)
-
        for handle in self.hook_handles:
            handle.remove()
        self._export_drop_layers()
@@ -1575,14 +1537,7 @@ def export(self) -> torch.nn.Module:
        if is_pipeline_last_stage():
            getattr(self.decoder, self.final_norm_attr_name).export()
            self.output_layer.export()
-        super().export()
-        return self
-
-    def freeze(self) -> None:
-        """Freeze the dynamic module."""
-        super().freeze()
-        for layer in self.decoder.layers:
-            layer.freeze()
+        return super().export()

    def get_activations_and_layer_scores(
        self,