
Commit ce1c306

Remove unnecessary Tracing use in Minitron Search Space generation
Signed-off-by: Keval Morabia <[email protected]>
1 parent b39c73d commit ce1c306

File tree

7 files changed: +93 -92 lines changed

modelopt/torch/nas/plugins/megatron.py

Lines changed: 2 additions & 12 deletions

@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Plugin to add NAS/Pruning support for megatron-core GPT model."""
+"""Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""
 
 import types
 from collections.abc import Callable, Sequence
 from typing import Any
-from warnings import warn
 
 import torch
 import torch.nn as nn

@@ -98,7 +97,7 @@
 except ImportError:
     HAS_MAMBA = False
 
-__all__ = ["drop_mcore_gpt_layers", "drop_mcore_language_model_layers"]
+__all__ = ["drop_mcore_language_model_layers"]
 
 
 class _DynamicParallelLinear(DynamicModule):

@@ -1457,15 +1456,6 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]
     model.config.num_layers = new_num_layers
 
 
-def drop_mcore_gpt_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
-    """[DEPRECATED] Remove given layers (1-indexed) of the model (works with TP and/or PP)."""
-    warn(
-        "`drop_mcore_gpt_layers` is deprecated in favor of `drop_mcore_language_model_layers`.",
-        DeprecationWarning,
-    )
-    drop_mcore_language_model_layers(model, layers_to_drop=layers_to_drop)
-
-
 class MegatronConstraintsFunc(ConstraintsFunc):
     """A Functor class to check if sub-net satisfied all provided constraints.
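
With the deprecated `drop_mcore_gpt_layers` shim removed outright, callers must switch to the consolidated helper. A minimal migration sketch (the `model` object and layer indices below are placeholders):

# Before (removed in this commit):
#   from modelopt.torch.nas.plugins.megatron import drop_mcore_gpt_layers
#   drop_mcore_gpt_layers(model, layers_to_drop=[1, 2])

# After: one helper covers GPT and Mamba language models (works with TP and/or PP).
from modelopt.torch.nas.plugins.megatron import drop_mcore_language_model_layers

drop_mcore_language_model_layers(model, layers_to_drop=[1, 2])  # layers are 1-indexed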

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 79 additions & 6 deletions

@@ -27,23 +27,35 @@
 import copy
 
 import torch
+import torch.nn as nn
 from pydantic import create_model
 
 # isort: off
-# import nas plugin to check if it is enabled else raises an Exception
+# import nas plugin to check if it is enabled else raises an Exception and disables the plugin
 from modelopt.torch.nas.plugins.megatron import *  # noqa: F403
-from modelopt.torch.nas.plugins.megatron import HAS_MAMBA, _DynamicMCoreLanguageModel
+from modelopt.torch.nas.plugins.megatron import (
+    HAS_MAMBA,
+    _DynamicMCoreLanguageModel,
+    SUPPORTED_MODELS,
+)
 # isort: on
 
 from modelopt.torch.nas.conversion import NASModeRegistry
 from modelopt.torch.nas.registry import DMRegistry
-from modelopt.torch.nas.utils import sort_parameters
+from modelopt.torch.nas.utils import get_subnet_config, sort_parameters
 from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
+from modelopt.torch.opt.conversion import ApplyModeError
+from modelopt.torch.opt.dynamic import DynamicSpace
+from modelopt.torch.opt.mode import (
+    ConvertEntrypoint,
+    ConvertReturnType,
+    ModeDescriptor,
+    RestoreEntrypoint,
+)
 from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict
 from modelopt.torch.opt.utils import named_hparams
 from modelopt.torch.utils import print_rank_0
 
-from ..fastnas import FastNASModeDescriptor
 from ..pruning import PruneModeRegistry
 
 SUPPORTED_HPARAMS = {

@@ -58,6 +70,8 @@
     "num_layers",
 }
 
+__all__ = ["MCoreMinitronConfig", "MCoreMinitronModeDescriptor", "MCoreMinitronSearcher"]
+
 
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""

@@ -218,9 +232,48 @@ def run_search(self) -> None:
         )
 
 
+def _convert_model_to_dynamic_space(
+    model: nn.Module, config: ModeloptBaseConfig | None = None
+) -> DynamicSpace:
+    """Create a dynamic space for the model (in-place)."""
+    dynamic_space = DynamicSpace(model)
+    dynamic_space._should_be_converted = lambda mod: isinstance(mod, tuple(SUPPORTED_MODELS.keys()))
+    dynamic_space.convert_to_dynamic(config.model_dump() if config else None, DMRegistry)
+    if not dynamic_space.is_configurable():
+        raise ApplyModeError(
+            "The model does not contain any configurable hyperparameters! Please check the"
+            " documentation for modules and config and how to get a configurable model."
+        )
+
+    return dynamic_space
+
+
+def convert_mcore_minitron(model: nn.Module, config: ModeloptBaseConfig) -> ConvertReturnType:
+    """Convert the model to the dynamic search space (in-place) and return the converted model and metadata.
+
+    This is a simplified version of convert_fastnas_searchspace that removes the automated recursive
+    tracing and instead directly converts the top-level model to a DynamicModule. Submodules should not
+    need to be explicitly converted as that happens from the top-level model.
+    """
+    _convert_model_to_dynamic_space(model, config)
+
+    # store current config in metadata
+    metadata = {"subnet_config": get_subnet_config(model)}
+
+    # return converted model as well as metadata
+    return model, metadata
+
+
+def restore_mcore_minitron(
+    model: nn.Module, config: ModeloptBaseConfig, metadata: dict
+) -> nn.Module:
+    """Restore the model to the original state."""
+    return convert_mcore_minitron(model, config)[0]
+
+
 @NASModeRegistry.register_mode
 @PruneModeRegistry.register_mode
-class MCoreMinitronModeDescriptor(FastNASModeDescriptor):
+class MCoreMinitronModeDescriptor(ModeDescriptor):
     """Class to describe the ``"mcore_minitron"`` mode.
 
     The properties of this mode can be inspected via the source code.

@@ -236,7 +289,27 @@ def config_class(self) -> type[ModeloptBaseConfig]:
         """Specifies the config class for the mode."""
         return MCoreMinitronConfig
 
+    @property
+    def next_modes(self) -> set[str] | None:
+        """Modes that must immediately follow this mode."""
+        return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
+
+    @property
+    def export_mode(self) -> str | None:
+        """The mode that corresponds to the export mode of this mode."""
+        return "export"
+
     @property
     def search_algorithm(self) -> type[BaseSearcher]:
-        """Specifies the search algorithm to use for this mode (if any)."""
+        """Specifies the search algorithm to use for this mode."""
         return MCoreMinitronSearcher
+
+    @property
+    def convert(self) -> ConvertEntrypoint:
+        """The mode's entrypoint for converting a model to a search space."""
+        return convert_mcore_minitron
+
+    @property
+    def restore(self) -> RestoreEntrypoint:
+        """The mode's entrypoint for restoring a model with the modelopt_state."""
+        return restore_mcore_minitron
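
Because the mode now ships its own `convert`/`restore` entrypoints instead of inheriting FastNAS's tracing-based search-space generation, applying it builds the dynamic space directly from the top-level Megatron-Core module. A rough usage sketch, assuming the `mtp.prune` call shape used in ModelOpt's Minitron pruning examples; `model`, `forward_loop`, and the `export_config` values below are placeholders, not part of this commit:

import modelopt.torch.prune as mtp

def forward_loop(model):
    """Placeholder calibration loop: run a few batches through `model` to collect activations."""
    ...

# `model` is an existing Megatron-Core GPT/Mamba model (not constructed here).
export_config = {  # illustrative target sizes; allowed keys follow SUPPORTED_HPARAMS
    "hidden_size": 3072,
    "ffn_hidden_size": 9216,
    "num_layers": 28,
}
pruned_model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # not used by this mode
    config={"forward_loop": forward_loop},
)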

modelopt/torch/trace/plugins/__init__.py

Lines changed: 2 additions & 15 deletions

@@ -15,20 +15,7 @@
 
 """Handles tracing plugins for third-party modules."""
 
-import warnings as _warnings
+from modelopt.torch.utils import import_plugin
 
-try:
-    from .megatron import *
-
-except ImportError:
-    pass
-except Exception as e:
-    _warnings.warn(f"Failed to import megatron plugin due to: {e!r}")
-
-try:
+with import_plugin("transformers"):
     from .transformers import *
-
-except ImportError:
-    pass
-except Exception as e:
-    _warnings.warn(f"Failed to import transformers plugin due to: {e!r}")

modelopt/torch/trace/plugins/megatron.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py

Lines changed: 6 additions & 5 deletions

@@ -48,8 +48,9 @@
     _DynamicVocabParallelEmbedding,
     expand_head_indices,
 )
-from modelopt.torch.nas.search_space import generate_search_space
+from modelopt.torch.nas.registry import DMRegistry
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
+from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
 from modelopt.torch.utils import flatten_tree
 from modelopt.torch.utils.random import centroid

@@ -178,7 +179,7 @@ def _test_gpt_parameter_sorting(activation_func, rank, size):
         m.weight.data = torch.randn_like(m.weight)
 
     model.eval()
-    search_space = generate_search_space(model)
+    dynamic_space = _convert_model_to_dynamic_space(model)
 
     # Compute activations for sorting
     for _ in range(5):

@@ -188,18 +189,18 @@ def _test_gpt_parameter_sorting(activation_func, rank, size):
     prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
     y1 = run_mcore_inference(model, prompt_tokens)
 
-    search_space.sort_parameters()
+    mtn.utils.sort_parameters(model)
 
     # check if all ffn_hidden_size, num_heads_per_group, num_query_groups, hidden_size have been sorted
     sortable_per_pp = [
-        n for n, hp in search_space.named_hparams(configurable=True) if hp.importance is not None
+        n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
     ]
    # 3 hps per layer + 1 for hidden_size (num_layers is not sorted!)
     assert len(sortable_per_pp) == 3 * num_layers // size + 1
 
     # Export since sorting force reassigns SelfAttention weights which we dont want to re-sort!
     # TODO: ideally we shouldn't need this
-    search_space.export()
+    dynamic_space.export(DMRegistry)
 
     # sanity check if the model functionality is preserved after sorting
     y2 = run_mcore_inference(model, prompt_tokens)
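
Condensed, the updated test flow is sketched below (assuming `mtn` is the `modelopt.torch.nas` alias used by the tests; `model` is a placeholder for an already-built Megatron-Core model). The tracing-based `generate_search_space` is gone and the dynamic space is built directly:

import modelopt.torch.nas as mtn  # assumed alias, as in the test imports
from modelopt.torch.nas.registry import DMRegistry
from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space

# `model` is a Megatron-Core GPT/Mamba model set up for this rank (placeholder here).
dynamic_space = _convert_model_to_dynamic_space(model)  # in-place conversion, no tracing

# ... run a few forward passes so hparams accumulate importance scores ...

mtn.utils.sort_parameters(model)   # sort channels/heads by importance
dynamic_space.export(DMRegistry)   # export back to regular (non-dynamic) modules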

tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py

Lines changed: 4 additions & 4 deletions

@@ -43,9 +43,9 @@
     _DynamicRowParallelLinear,
     _DynamicVocabParallelEmbedding,
 )
-from modelopt.torch.nas.search_space import generate_search_space
 from modelopt.torch.nas.traced_hp import TracedHp
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
+from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
 from modelopt.torch.utils import flatten_tree
 from modelopt.torch.utils.random import centroid

@@ -163,7 +163,7 @@ def _test_mamba_parameter_sorting(rank, size):
         m.weight.data = torch.randn_like(m.weight)
 
     model.eval()
-    search_space = generate_search_space(model)
+    dynamic_space = _convert_model_to_dynamic_space(model)
 
     # Compute activations for sorting
     for _ in range(5):

@@ -173,11 +173,11 @@ def _test_mamba_parameter_sorting(rank, size):
     prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
    y1 = run_mcore_inference(model, prompt_tokens)
 
-    search_space.sort_parameters()
+    dynamic_space.sort_parameters()
 
     # check if all mamba_num_heads, mamba_head_dim, hidden_size have been sorted
     sortable_per_pp = [
-        n for n, hp in search_space.named_hparams(configurable=True) if hp.importance is not None
+        n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
     ]
     # 2 mamba hps per layer + 1 for hidden_size (num_layers is not sorted!)
     assert len(sortable_per_pp) == 2 * num_layers // size + 1

tests/unit/torch/trace/test_symbol.py

Lines changed: 0 additions & 14 deletions

@@ -139,20 +139,6 @@ def test_sym_map_registry():
     except ImportError:
         pass
 
-    try:
-        from megatron.core.models.gpt import GPTModel
-
-        mods_in_registry.add(GPTModel)
-    except ImportError:
-        pass
-
-    try:
-        from megatron.core.models.mamba import MambaModel
-
-        mods_in_registry.add(MambaModel)
-    except ImportError:
-        pass
-
     not_a_leaf = {nn.Sequential}
     dependent_registry = set()
