Commit 5b62318

DynamicModuleList, moe_ffn hparam, remove force_assign
Signed-off-by: Keval Morabia <[email protected]>
1 parent 050e1a5 commit 5b62318

11 files changed: +111, -109 lines

examples/megatron-lm/README.md

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Availab
 - `TARGET_MAMBA_NUM_HEADS`
 - `TARGET_MAMBA_HEAD_DIM`
 - `TARGET_NUM_MOE_EXPERTS`
+- `TARGET_MOE_FFN_HIDDEN_SIZE`
 - `TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE`
 - `TARGET_NUM_LAYERS`
 - `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)

examples/pruning/README.md

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ If your model parameters are already sorted, you can skip the sorting step by se
 
 | **Algorithm** | **Model** | **Pruning Constraints** |
 | :---: | :---: | :---: |
-| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
+| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
 | FastNAS | Computer Vision models | flops, parameters |
 | GradNAS | HuggingFace BERT, GPT-J | flops, parameters |
 
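The new `moe_ffn_hidden_size` key slots into the same Minitron export config as the other width hyperparameters. Below is a minimal sketch of a pruning call that uses it; the `mtp.prune` keyword arguments and the `forward_loop` helper are assumptions modeled on the existing pruning example, not part of this diff:

```python
import modelopt.torch.prune as mtp

# Width targets for the pruned MoE model; moe_ffn_hidden_size is the hparam added in this commit.
export_config = {
    "num_moe_experts": 4,         # number of experts to keep
    "moe_ffn_hidden_size": 1024,  # per-expert FFN width to keep
}

# `model` is a Megatron-core MoE model and `forward_loop` runs calibration data through it
# to collect activation-based importance scores (both assumed to exist in the caller's scope).
pruned_model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # not used by this mode
    config={"forward_loop": forward_loop},
)
```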

modelopt/torch/nas/modules/container.py

Lines changed: 33 additions & 1 deletion
@@ -26,7 +26,7 @@
 from ..registry import DMRegistry
 from ..traced_hp import TracedHp
 
-__all__ = ["_DynamicSequential"]
+__all__ = ["DynamicModuleList", "_DynamicSequential"]
 
 
 def _activate_depth(func: Callable) -> Callable:
@@ -97,3 +97,35 @@ def modify(self, *, min_depth: int = 0):
         """
         hp = self.get_hparam("depth")
         hp.choices = [d for d in hp.choices if d >= min_depth]
+
+
+# NOTE: We provide a parent class since we do not register to DMRegistry and explicitly convert a module if needed.
+class DynamicModuleList(DynamicModule, nn.ModuleList):
+    """An ``nn.ModuleList`` container with dynamic hyperparams and variable ``depth``.
+
+    Unlike _DynamicSequential, this module supports sorting/reordering of modules based on
+    importance in addition to variable depth.
+    """
+
+    def _setup(self):
+        # register hyperparameters
+        self._register_hparam("depth", TracedHp(list(range(1, len(self) + 1))))
+
+        # register _modules as a dynamic attribute
+        self._register_dynamic_attribute("_modules", self._get_modules)
+
+    @staticmethod
+    def _get_modules(mod: "DynamicModuleList", modules: dict) -> dict:
+        """Get modules with dynamic depth and ordering applied based on active_slice."""
+        hp = mod.get_hparam("depth")
+        active_slice = hp.active_slice
+
+        items = list(modules.items())
+
+        if isinstance(active_slice, slice):
+            active_items = items[active_slice]
+        else:
+            active_items = [items[idx] for idx in active_slice.tolist()]
+
+        # Re-create dict with keys as str(index) from 0 to len(active_items)
+        return {str(i): module for i, (_, module) in enumerate(active_items)}
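For orientation, here is a small usage sketch of the new container (not part of the diff). It assumes the usual ModelOpt convention that assigning a value to a registered hparam attribute, here `depth`, selects the active choice:

```python
import torch.nn as nn

from modelopt.torch.nas.modules import DynamicModuleList

# A plain ModuleList of 8 toy "experts".
experts = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))

# DynamicModuleList is not in DMRegistry (see the NOTE above), so convert it explicitly.
experts = DynamicModuleList.convert(experts)

# A "depth" hparam with choices 1..8 is now registered. Because `_modules` is a dynamic
# attribute, len(), iteration, and indexing all reflect the active (possibly reordered) subset.
experts.depth = 4
assert len(experts) == 4
```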

modelopt/torch/nas/plugins/megatron.py

Lines changed: 21 additions & 52 deletions
@@ -16,6 +16,7 @@
 """Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""
 
 import types
+from abc import ABC
 from collections.abc import Callable, Sequence
 from typing import Any
 
@@ -52,10 +53,10 @@
 from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.transformer_layer import TransformerLayer
 
+from modelopt.torch.nas.modules import DynamicModuleList
 from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.hparam import HPType
 from modelopt.torch.opt.searcher import ConstraintsDict
-from modelopt.torch.opt.utils import named_hparams
 from modelopt.torch.trace import Symbol
 from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import (
@@ -201,6 +202,8 @@ def _setup(self):
         )
         if isinstance(self, SharedExpertMLP):
             self.hparam_name = "moe_shared_expert_intermediate_size"
+        elif self.config.num_moe_experts is not None:
+            self.hparam_name = "moe_ffn_hidden_size"
         else:
             self.hparam_name = "ffn_hidden_size"
         self.linear_fc1 = DMRegistry.convert(self.linear_fc1)
@@ -650,8 +653,9 @@ def export(self) -> torch.nn.Module:
 
 
 # MoE DynamicModules ###############################################################################
+# Add ABC to avoid TypeError: object layout differs (because parent if TopKRouter inherits from ABC)
 @DMRegistry.register({TopKRouter: "megatron.core.transformer.moe.router.TopKRouter"})
-class _DynamicTopKRouter(DynamicModule):
+class _DynamicTopKRouter(DynamicModule, ABC):
     """A TopKRouter with dynamic hyperparams."""
 
     def _setup(self):
@@ -660,11 +664,11 @@ def _setup(self):
         # Register num_moe_experts hparam name to match TransformerConfig's name.
         # Will be overridden by _DynamicSequentialMLP's hp.
         self._register_hparam("num_moe_experts", TracedHp(list(range(1, self.weight.shape[0] + 1))))
-        self._register_dynamic_attribute("num_experts", lambda mod, val: mod.num_moe_experts)
         # Register hidden_size reference (will be overridden by _DynamicMoELayer's hidden_size)
         self._register_hparam("hidden_size", TracedHp(list(range(1, self.weight.shape[1] + 1))))
 
         # Register dynamic attributes
+        self._register_dynamic_attribute("num_experts", lambda mod, val: mod.num_moe_experts)
         self._register_dynamic_attribute("weight", self._get_router_weight)
         if self.enable_expert_bias:
             self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_moe_experts)
@@ -702,7 +706,9 @@ def _setup(self):
             lambda mod, val: mod.num_moe_experts,  # EP = 1
         )
 
-        # Convert each individual expert MLP to dynamic
+        # Convert local_experts list and each individual expert MLP to dynamic modules
+        self.local_experts = DynamicModuleList.convert(self.local_experts)
+        self.local_experts.depth = num_moe_experts  # Reuse same hparam for depth
         for i in range(len(self.local_experts)):
             self.local_experts[i] = DMRegistry.convert(self.local_experts[i])
 
@@ -725,6 +731,11 @@ def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
 
     def _expert_l2_imp_forward_hook(self, module, input, output):
         """Track expert importance based on L2 norms of expert outputs."""
+        # Dont aggregate activations from non-max subnets (e.g. from profiling)
+        num_moe_experts = self.get_hparam("num_moe_experts")
+        if num_moe_experts.active != num_moe_experts.max:
+            return
+
         # Split output back to per-expert outputs using torch.split
         tokens_per_expert_list = input[1].tolist()
         # use full precision to avoid overflow
@@ -757,26 +768,11 @@ def _estimate_expert_importance(self) -> TracedHp.Importance:
             self._activations["expert_sample_counts"] + 1e-8
         )
 
-    def _export_drop_experts(self) -> None:
-        """Drop experts during export based on active hyperparameter value."""
-        # Get sorted + trimmed order of experts to keep
-        active_slice = self.get_hparam("num_moe_experts").active_slice
-
-        # Trim experts based on active hparam value
-        if isinstance(active_slice, slice):
-            kept_experts = self.local_experts[: active_slice.stop]
-        else:
-            kept_experts = [self.local_experts[i] for i in active_slice]
-
-        # Replace the ModuleList with pruned experts
-        self.local_experts = nn.ModuleList(kept_experts)
-
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a standard SequentialMLP."""
         self.hook_handle.remove()
 
-        # Drop experts based on active hparam value and export remaining experts
-        self._export_drop_experts()
+        self.local_experts.export()
         for expert in self.local_experts:
             expert.export()
 
@@ -789,9 +785,6 @@ class _DynamicMoELayer(DynamicModule):
     """A MoELayer with dynamic hyperparams."""
 
     def _setup(self):
-        # TODO: Add DynamicTokenDispatcher for moe_shared_expert_overlap support
-        assert not self.shared_expert_overlap, "moe_shared_expert_overlap is not supported yet!"
-
         # Convert to dynamic modules
         # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so
         # importance estimator is not lost.
@@ -810,7 +803,11 @@ def _setup(self):
 
     def _get_local_expert_indices(self, mod: "_DynamicMoELayer", val: list[int]) -> list[int]:
         """Get local expert indices for the current active hparam value."""
-        return list(range(mod.num_local_experts))
+        active_slice = self.experts.get_hparam("num_moe_experts").active_slice
+        if isinstance(active_slice, slice):
+            return list(range(active_slice.stop))
+        else:
+            return active_slice.tolist()
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for all MoE components from global hidden_size hparam."""
@@ -955,17 +952,6 @@ def export(self):
         super().export()
         return self
 
-    def freeze(self):
-        """Freeze the dynamic module."""
-        super().freeze()
-        if isinstance(self.self_attention, SelfAttention):
-            self.input_layernorm.freeze()
-            self.self_attention.freeze()
-
-        if isinstance(self.mlp, (MLP, MoELayer)):
-            self.pre_mlp_layernorm.freeze()
-            self.mlp.freeze()
-
 
 # Mamba DynamicModules #############################################################################
 class MambaNumHeadsHp(TracedHp):
@@ -1356,11 +1342,6 @@ def export(self):
         super().export()
         return self
 
-    def freeze(self):
-        """Freeze the hyperparameters."""
-        self.mixer.freeze()
-        super().freeze()
-
 
 if HAS_MAMBA:
     DMRegistry.register({ExtendedRMSNorm: "megatron.core.ssm.mamba_mixer.ExtendedRMSNorm"})(
@@ -1559,12 +1540,6 @@ def _export_drop_layers(self) -> None:
 
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a torch.nn.Module."""
-        # TODO: Improve this!
-        # Slice order needs to be reset before exporting since weights are already
-        # force assigned and we dont want to sort them again (losing the correct order)
-        for n, hp in named_hparams(self, configurable=True):
-            hp.enforce_order(None)
-
         for handle in self.hook_handles:
             handle.remove()
         self._export_drop_layers()
@@ -1578,12 +1553,6 @@ def export(self) -> torch.nn.Module:
         super().export()
         return self
 
-    def freeze(self) -> None:
-        """Freeze the dynamic module."""
-        super().freeze()
-        for layer in self.decoder.layers:
-            layer.freeze()
-
     def get_activations_and_layer_scores(
         self,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]:
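The rewritten `_get_local_expert_indices` follows the same pattern used throughout this plugin: an hparam's `active_slice` is either a contiguous `slice` (no reordering) or a tensor of importance-sorted indices. A standalone illustration of that mapping (a plain function for clarity, not the actual dynamic-module code):

```python
import torch


def kept_expert_indices(active_slice) -> list[int]:
    """Map an hparam's active_slice to the list of expert indices that are kept."""
    if isinstance(active_slice, slice):
        # Contiguous case: keep the first `stop` experts in their original order.
        return list(range(active_slice.stop))
    # Reordered case: a tensor of importance-sorted indices.
    return active_slice.tolist()


print(kept_expert_indices(slice(0, 4)))              # [0, 1, 2, 3]
print(kept_expert_indices(torch.tensor([5, 2, 7])))  # [5, 2, 7]
```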

modelopt/torch/nas/search_space.py

Lines changed: 0 additions & 4 deletions
@@ -162,10 +162,6 @@ def sort_parameters(self, hps_to_sort: set[str] | None = None, verbose: bool = F
                 f"{'order' if hp._importance_is_order else 'importance'}={importance}"
             )
 
-        # now that we have enforced an order we can force reassign all parameters/buffers!
-        for _, mod in self.named_dynamic_modules():
-            mod.force_assign()
-
         # go back to old config
         self.select(config)
 
modelopt/torch/opt/dynamic.py

Lines changed: 0 additions & 22 deletions
@@ -586,28 +586,6 @@ def export(self) -> nn.Module:
 
         return self
 
-    @torch.no_grad()
-    def force_assign(self):
-        """Force re-assign all dynamic attributes to their current values.
-
-        .. warning::
-
-            Note that this method overwrites the actual buffers and parameters! Only use in
-            specific circumstances!!
-        """
-        # force-reassign all dynamic attributes
-        for name in self._get_dm_attribute_manager().da_keys():
-            val = getattr(self, name)
-            if isinstance(val, torch.Tensor):
-                val = val.detach().clone()
-            if name in self._parameters:
-                val = val if val is None else Parameter(val)
-                self.register_parameter(name, val)
-            elif name in self._buffers:
-                self.register_buffer(name, val)
-            else:
-                setattr(self, name, val)
-
     @classmethod
     @torch.no_grad()
     def convert(cls, module: nn.Module) -> "DynamicModule":

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 2 additions & 1 deletion
@@ -71,8 +71,9 @@
     "mamba_num_heads",
    "mamba_head_dim",
     # MoE
-    "num_moe_experts",
+    "moe_ffn_hidden_size",
     "moe_shared_expert_intermediate_size",
+    "num_moe_experts",
     # 2. Depth pruning
     "num_layers",
 }

tests/_test_utils/torch/megatron/models.py

Lines changed: 2 additions & 0 deletions
@@ -146,6 +146,7 @@ def get_mcore_gpt_model(
     use_te: bool = False,
     # MoE-specific parameters
     moe_grouped_gemm: bool = False,
+    moe_ffn_hidden_size: int | None = None,
     moe_shared_expert_intermediate_size: int | None = None,
     num_moe_experts: int | None = None,
 ) -> GPTModel:
@@ -188,6 +189,7 @@ def squared_relu(x):
         # MoE-specific parameters
         moe_grouped_gemm=moe_grouped_gemm,
         moe_router_dtype="fp32",
+        moe_ffn_hidden_size=moe_ffn_hidden_size,
         moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
         num_moe_experts=num_moe_experts,
     )
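With the new keyword threaded through, tests can build a small MoE GPT model whose experts use a different FFN width than the dense `ffn_hidden_size`. A hedged sketch of such a call; only the MoE keywords are taken from this diff, while the remaining argument names and values are assumptions about the helper's signature:

```python
from _test_utils.torch.megatron.models import get_mcore_gpt_model  # test helper patched above

model = get_mcore_gpt_model(
    num_layers=2,             # assumed helper argument
    hidden_size=64,           # assumed helper argument
    num_attention_heads=8,    # assumed helper argument
    num_moe_experts=4,
    moe_ffn_hidden_size=128,  # new kwarg introduced in this commit
)
```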

tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py

Lines changed: 0 additions & 5 deletions
@@ -48,7 +48,6 @@
     _DynamicVocabParallelEmbedding,
     expand_head_indices,
 )
-from modelopt.torch.nas.registry import DMRegistry
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
 from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
 from modelopt.torch.utils import flatten_tree
@@ -198,10 +197,6 @@ def _test_gpt_parameter_sorting(activation_func, rank, size):
     # 3 hps per layer + 1 for hidden_size (num_layers is not sorted!)
     assert len(sortable_per_pp) == 3 * num_layers // size + 1
 
-    # Export since sorting force reassigns SelfAttention weights which we dont want to re-sort!
-    # TODO: ideally we shouldn't need this
-    dynamic_space.export(DMRegistry)
-
     # sanity check if the model functionality is preserved after sorting
     y2 = run_mcore_inference(model, prompt_tokens)