
Commit 0794a1e

DynamicModuleList, moe_ffn hparam, remove force_assign
Signed-off-by: Keval Morabia <[email protected]>
Parent: 050e1a5

11 files changed: +224 −151 lines


examples/megatron-lm/README.md

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Availab
 - `TARGET_MAMBA_NUM_HEADS`
 - `TARGET_MAMBA_HEAD_DIM`
 - `TARGET_NUM_MOE_EXPERTS`
+- `TARGET_MOE_FFN_HIDDEN_SIZE`
 - `TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE`
 - `TARGET_NUM_LAYERS`
 - `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)

examples/pruning/README.md

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ If your model parameters are already sorted, you can skip the sorting step by se
 
 | **Algorithm** | **Model** | **Pruning Constraints** |
 | :---: | :---: | :---: |
-| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
+| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
 | FastNAS | Computer Vision models | flops, parameters |
 | GradNAS | HuggingFace BERT, GPT-J | flops, parameters |
 
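For context, the new `moe_ffn_hidden_size` key sits alongside the other width parameters in a Minitron export config. A hypothetical sketch of how it could be passed to pruning; the mode name, `prune()` call shape, and `forward_loop` below are assumptions for illustration, not taken from this commit (see the pruning README for the actual API):

```python
# Hypothetical sketch -- mode name, prune() call shape, and forward_loop
# are assumptions, not part of this commit.
import modelopt.torch.prune as mtp

export_config = {
    "num_moe_experts": 4,                         # keep the 4 most important experts
    "moe_ffn_hidden_size": 1024,                  # new: per-expert FFN hidden size
    "moe_shared_expert_intermediate_size": 2048,  # shared-expert width
    "num_layers": 24,                             # optional depth pruning
}

pruned_model, _ = mtp.prune(
    model,                                        # assumed: a Megatron-core MoE model in scope
    mode="mcore_minitron",                        # assumed Minitron mode name
    constraints={"export_config": export_config},
    config={"forward_loop": forward_loop},        # assumed importance-scoring loop
)
```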

modelopt/torch/nas/modules/container.py

Lines changed: 33 additions & 1 deletion
@@ -26,7 +26,7 @@
 from ..registry import DMRegistry
 from ..traced_hp import TracedHp
 
-__all__ = ["_DynamicSequential"]
+__all__ = ["DynamicModuleList", "_DynamicSequential"]
 
 
 def _activate_depth(func: Callable) -> Callable:
@@ -97,3 +97,35 @@ def modify(self, *, min_depth: int = 0):
         """
         hp = self.get_hparam("depth")
         hp.choices = [d for d in hp.choices if d >= min_depth]
+
+
+# NOTE: We provide a parent class since we do not register to DMRegistry and explicitly convert a module if needed.
+class DynamicModuleList(DynamicModule, nn.ModuleList):
+    """An ``nn.ModuleList`` container with dynamic hyperparams and variable ``depth``.
+
+    Unlike _DynamicSequential, this module supports sorting/reordering of modules based on
+    importance in addition to variable depth.
+    """
+
+    def _setup(self):
+        # register hyperparameters
+        self._register_hparam("depth", TracedHp(list(range(1, len(self) + 1))))
+
+        # register _modules as a dynamic attribute
+        self._register_dynamic_attribute("_modules", self._get_modules)
+
+    @staticmethod
+    def _get_modules(mod: "DynamicModuleList", modules: dict) -> dict:
+        """Get modules with dynamic depth and ordering applied based on active_slice."""
+        hp = mod.get_hparam("depth")
+        active_slice = hp.active_slice
+
+        items = list(modules.items())
+
+        if isinstance(active_slice, slice):
+            active_items = items[active_slice]
+        else:
+            active_items = [items[idx] for idx in active_slice.tolist()]
+
+        # Re-create dict with keys as str(index) from 0 to len(active_items)
+        return {str(i): module for i, (_, module) in enumerate(active_items)}
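A minimal usage sketch of the new container, based only on the class above; the toy `nn.Linear` modules and setting the active value directly on the `depth` hparam are assumptions for illustration:

```python
# Illustrative sketch (assumptions: toy Linear modules; setting the active
# depth value directly on the hparam).
import torch.nn as nn

from modelopt.torch.nas.modules import DynamicModuleList

experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
dyn_experts = DynamicModuleList.convert(experts)  # converted explicitly, not via DMRegistry

# "depth" is a TracedHp over 1..4; shrinking it keeps only the active modules,
# in importance order once an order has been enforced on the hparam.
hp = dyn_experts.get_hparam("depth")
hp.active = 2  # assumed: the active value can be set directly on the hparam
print(len(dyn_experts))  # -> 2, since _modules is re-built from active_slice
```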

modelopt/torch/nas/plugins/megatron.py

Lines changed: 48 additions & 93 deletions
@@ -16,6 +16,7 @@
 """Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""
 
 import types
+from abc import ABC
 from collections.abc import Callable, Sequence
 from typing import Any
 
@@ -52,10 +53,10 @@
 from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.transformer_layer import TransformerLayer
 
+from modelopt.torch.nas.modules import DynamicModuleList
 from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.hparam import HPType
 from modelopt.torch.opt.searcher import ConstraintsDict
-from modelopt.torch.opt.utils import named_hparams
 from modelopt.torch.trace import Symbol
 from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import (
@@ -201,6 +202,8 @@ def _setup(self):
         )
         if isinstance(self, SharedExpertMLP):
             self.hparam_name = "moe_shared_expert_intermediate_size"
+        elif self.config.num_moe_experts is not None:
+            self.hparam_name = "moe_ffn_hidden_size"
         else:
             self.hparam_name = "ffn_hidden_size"
         self.linear_fc1 = DMRegistry.convert(self.linear_fc1)
@@ -274,8 +277,7 @@ def export(self) -> torch.nn.Module:
         self.hook_handle.remove()
         self.linear_fc1.export()
         self.linear_fc2.export()
-        super().export()
-        return self
+        return super().export()
 
 
 # SelfAttention DynamicModules #####################################################################
@@ -645,42 +647,37 @@ def export(self) -> torch.nn.Module:
         self.core_attention.export()
         self.linear_qkv.export()
         self.linear_proj.export()
-        super().export()
-        return self
+        return super().export()
 
 
 # MoE DynamicModules ###############################################################################
+# Add ABC to avoid TypeError: object layout differs (because the parent class TopKRouter inherits from ABC)
 @DMRegistry.register({TopKRouter: "megatron.core.transformer.moe.router.TopKRouter"})
-class _DynamicTopKRouter(DynamicModule):
+class _DynamicTopKRouter(ABC, DynamicModule):
     """A TopKRouter with dynamic hyperparams."""
 
     def _setup(self):
-        # Register hparams for router weight dimensions
+        # Register hparams for router weight dimensions (will be overridden by _DynamicSequentialMLP's hp)
         # Router weight shape: [num_moe_experts, hidden_size]
-        # Register num_moe_experts hparam name to match TransformerConfig's name.
-        # Will be overridden by _DynamicSequentialMLP's hp.
-        self._register_hparam("num_moe_experts", TracedHp(list(range(1, self.weight.shape[0] + 1))))
-        self._register_dynamic_attribute("num_experts", lambda mod, val: mod.num_moe_experts)
+        self._register_hparam("num_experts", TracedHp(list(range(1, self.weight.shape[0] + 1))))
         # Register hidden_size reference (will be overridden by _DynamicMoELayer's hidden_size)
         self._register_hparam("hidden_size", TracedHp(list(range(1, self.weight.shape[1] + 1))))
 
         # Register dynamic attributes
         self._register_dynamic_attribute("weight", self._get_router_weight)
         if self.enable_expert_bias:
-            self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_moe_experts)
+            self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_experts)
             self._register_dynamic_attribute(
-                "local_tokens_per_expert", self._get_slice_by_num_moe_experts
+                "local_tokens_per_expert", self._get_slice_by_num_experts
            )
 
     @staticmethod
     def _get_router_weight(mod: "_DynamicTopKRouter", weight: torch.Tensor) -> torch.Tensor:
-        return get_sliced_tensor(mod, weight, "num_moe_experts", "hidden_size")
+        return get_sliced_tensor(mod, weight, "num_experts", "hidden_size")
 
     @staticmethod
-    def _get_slice_by_num_moe_experts(
-        mod: "_DynamicTopKRouter", bias: torch.Tensor
-    ) -> torch.Tensor:
-        return get_sliced_tensor(mod, bias, "num_moe_experts")
+    def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", bias: torch.Tensor) -> torch.Tensor:
+        return get_sliced_tensor(mod, bias, "num_experts")
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden_size hparam for router weights from global hidden_size hparam."""
@@ -692,17 +689,13 @@ class _DynamicSequentialMLP(DynamicModule):
     """A SequentialMLP with dynamic hyperparams."""
 
     def _setup(self):
-        # Register hparam for number of active experts
-        # Use num_moe_experts hparam name to match TransformerConfig's name
-        # Will be shared with _DynamicTopKRouter's hp.
+        # Register hparam for number of active experts (will be shared with _DynamicTopKRouter's hp)
         num_moe_experts = TracedHp(list(range(1, self.num_local_experts + 1)))
-        self._register_hparam("num_moe_experts", num_moe_experts)
-        self._register_dynamic_attribute(
-            "num_local_experts",
-            lambda mod, val: mod.num_moe_experts,  # EP = 1
-        )
+        self._register_hparam("num_local_experts", num_moe_experts)
 
-        # Convert each individual expert MLP to dynamic
+        # Convert local_experts list and each individual expert MLP to dynamic modules
+        self.local_experts = DynamicModuleList.convert(self.local_experts)
+        self.local_experts.depth = num_moe_experts  # Reuse same hparam for depth
         for i in range(len(self.local_experts)):
             self.local_experts[i] = DMRegistry.convert(self.local_experts[i])
 
@@ -725,6 +718,11 @@ def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
 
     def _expert_l2_imp_forward_hook(self, module, input, output):
         """Track expert importance based on L2 norms of expert outputs."""
+        # Don't aggregate activations from non-max subnets (e.g. from profiling)
+        num_moe_experts = self.get_hparam("num_local_experts")
+        if num_moe_experts.active != num_moe_experts.max:
+            return
+
         # Split output back to per-expert outputs using torch.split
         tokens_per_expert_list = input[1].tolist()
         # use full precision to avoid overflow
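The guard added above skips accumulation whenever a smaller subnet is active (e.g. during profiling). When the full expert set is active, the hook splits the concatenated expert outputs back per expert and accumulates an L2-norm score per expert, roughly like this (shapes and variable names are illustrative, not from the source):

```python
# Conceptual sketch of the per-expert L2 importance signal.
import torch

tokens_per_expert = [5, 3, 7, 1]                  # input[1].tolist() in the hook
output = torch.randn(sum(tokens_per_expert), 16)  # concatenated expert outputs
per_expert = torch.split(output.float(), tokens_per_expert)  # full precision to avoid overflow
l2_per_expert = torch.stack([t.norm(p=2) for t in per_expert])
# Accumulated over batches (and normalized by sample counts) to rank experts.
```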
@@ -757,60 +755,50 @@ def _estimate_expert_importance(self) -> TracedHp.Importance:
             self._activations["expert_sample_counts"] + 1e-8
         )
 
-    def _export_drop_experts(self) -> None:
-        """Drop experts during export based on active hyperparameter value."""
-        # Get sorted + trimmed order of experts to keep
-        active_slice = self.get_hparam("num_moe_experts").active_slice
-
-        # Trim experts based on active hparam value
-        if isinstance(active_slice, slice):
-            kept_experts = self.local_experts[: active_slice.stop]
-        else:
-            kept_experts = [self.local_experts[i] for i in active_slice]
-
-        # Replace the ModuleList with pruned experts
-        self.local_experts = nn.ModuleList(kept_experts)
-
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a standard SequentialMLP."""
         self.hook_handle.remove()
-
-        # Drop experts based on active hparam value and export remaining experts
-        self._export_drop_experts()
         for expert in self.local_experts:
             expert.export()
-
-        super().export()
-        return self
+        self.local_experts.export()
+        return super().export()
 
 
 @DMRegistry.register({MoELayer: "megatron.core.transformer.moe.moe_layer.MoELayer"})
 class _DynamicMoELayer(DynamicModule):
     """A MoELayer with dynamic hyperparams."""
 
     def _setup(self):
-        # TODO: Add DynamicTokenDispatcher for moe_shared_expert_overlap support
-        assert not self.shared_expert_overlap, "moe_shared_expert_overlap is not supported yet!"
-
         # Convert to dynamic modules
         # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so
-        # importance estimator is not lost.
+        # importance estimator and depth hparam are retained.
         self.router = DMRegistry.convert(self.router)
         self.experts = DMRegistry.convert(self.experts)
-        num_moe_experts_hp = self.experts.get_hparam("num_moe_experts")
+        num_moe_experts_hp = self.experts.get_hparam("num_local_experts")
+
+        # NOTE: Use num_moe_experts hparam name in top-level module to match TransformerConfig's name
+        self._register_hparam("num_moe_experts", num_moe_experts_hp)
         self._register_dynamic_attribute(
             "num_local_experts",
             lambda mod, val: num_moe_experts_hp.active,  # EP = 1
         )
-        self.router.num_moe_experts = num_moe_experts_hp
+        self.router.num_experts = num_moe_experts_hp
         if self.use_shared_expert:
             self.shared_experts = DMRegistry.convert(self.shared_experts)
 
         self._register_dynamic_attribute("local_expert_indices", self._get_local_expert_indices)
+        # assert self.config.moe_token_dispatcher_type == "alltoall", (
+        #     "Only moe_token_dispatcher_type=='alltoall' is supported!"
+        # )
+        # self.token_dispatcher = DMRegistry.convert(self.token_dispatcher)
 
     def _get_local_expert_indices(self, mod: "_DynamicMoELayer", val: list[int]) -> list[int]:
         """Get local expert indices for the current active hparam value."""
-        return list(range(mod.num_local_experts))
+        active_slice = self.get_hparam("num_moe_experts").active_slice
+        if isinstance(active_slice, slice):
+            return list(range(active_slice.stop))
+        else:
+            return active_slice.tolist()
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for all MoE components from global hidden_size hparam."""
@@ -824,7 +812,7 @@ def modify(
     ):
         """Modify MoE hparam choices based on search space config."""
         # Modify num_moe_experts hparam choices (applies to both router and experts)
-        expert_hp = self.experts.get_hparam("num_moe_experts")
+        expert_hp = self.get_hparam("num_moe_experts")
         choices = {int(make_divisible(c, num_moe_experts_divisor)) for c in expert_hp.choices}  # type: ignore[arg-type]
         expert_hp.choices = list(set(expert_hp.choices) & choices | {expert_hp.original})
 
@@ -840,8 +828,7 @@ def export(self) -> torch.nn.Module:
         self.experts.export()
         if self.use_shared_expert:
             self.shared_experts.export()
-        super().export()
-        return self
+        return super().export()
 
 
 # TransformerLayer DynamicModule ###################################################################
@@ -952,19 +939,7 @@ def export(self):
         if isinstance(self.mlp, (MLP, MoELayer)):
             self.pre_mlp_layernorm.export()
             self.mlp.export()
-        super().export()
-        return self
-
-    def freeze(self):
-        """Freeze the dynamic module."""
-        super().freeze()
-        if isinstance(self.self_attention, SelfAttention):
-            self.input_layernorm.freeze()
-            self.self_attention.freeze()
-
-        if isinstance(self.mlp, (MLP, MoELayer)):
-            self.pre_mlp_layernorm.freeze()
-            self.mlp.freeze()
+        return super().export()
 
 
 # Mamba DynamicModules #############################################################################
@@ -1307,8 +1282,7 @@ def export(self) -> torch.nn.Module:
         self.conv1d.export()
         if self.rmsnorm:
             self.norm.export()
-        super().export()
-        return self
+        return super().export()
 
 
 class _DynamicMambaLayer(DynamicModule, MambaTransformerLayerMixin):
@@ -1353,13 +1327,7 @@ def export(self):
         self._export_mixin()
         self.mixer.export()
         self.norm.export()
-        super().export()
-        return self
-
-    def freeze(self):
-        """Freeze the hyperparameters."""
-        self.mixer.freeze()
-        super().freeze()
+        return super().export()
 
 
 if HAS_MAMBA:
@@ -1559,12 +1527,6 @@ def _export_drop_layers(self) -> None:
 
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a torch.nn.Module."""
-        # TODO: Improve this!
-        # Slice order needs to be reset before exporting since weights are already
-        # force assigned and we dont want to sort them again (losing the correct order)
-        for n, hp in named_hparams(self, configurable=True):
-            hp.enforce_order(None)
-
         for handle in self.hook_handles:
             handle.remove()
         self._export_drop_layers()
@@ -1575,14 +1537,7 @@ def export(self) -> torch.nn.Module:
         if is_pipeline_last_stage():
             getattr(self.decoder, self.final_norm_attr_name).export()
         self.output_layer.export()
-        super().export()
-        return self
-
-    def freeze(self) -> None:
-        """Freeze the dynamic module."""
-        super().freeze()
-        for layer in self.decoder.layers:
-            layer.freeze()
+        return super().export()
 
     def get_activations_and_layer_scores(
         self,

modelopt/torch/nas/search_space.py

Lines changed: 0 additions & 4 deletions
@@ -162,10 +162,6 @@ def sort_parameters(self, hps_to_sort: set[str] | None = None, verbose: bool = F
                 f"{'order' if hp._importance_is_order else 'importance'}={importance}"
             )
 
-        # now that we have enforced an order we can force reassign all parameters/buffers!
-        for _, mod in self.named_dynamic_modules():
-            mod.force_assign()
-
         # go back to old config
         self.select(config)
 
