2727import copy
2828
2929import torch
30+ import torch .nn as nn
3031from pydantic import create_model
3132
3233# isort: off
33- # import nas plugin to check if it is enabled else raises an Exception
34+ # import nas plugin to check if it is enabled else raises an Exception and disables the plugin
3435from modelopt .torch .nas .plugins .megatron import * # noqa: F403
35- from modelopt .torch .nas .plugins .megatron import HAS_MAMBA , _DynamicMCoreLanguageModel
36+ from modelopt .torch .nas .plugins .megatron import (
37+ HAS_MAMBA ,
38+ _DynamicMCoreLanguageModel ,
39+ SUPPORTED_MODELS ,
40+ )
3641# isort: on
3742
3843from modelopt .torch .nas .conversion import NASModeRegistry
3944from modelopt .torch .nas .registry import DMRegistry
40- from modelopt .torch .nas .utils import sort_parameters
45+ from modelopt .torch .nas .utils import get_subnet_config , sort_parameters
4146from modelopt .torch .opt .config import ModeloptBaseConfig , get_kwargs_for_create_model_with_rules
47+ from modelopt .torch .opt .conversion import ApplyModeError
48+ from modelopt .torch .opt .dynamic import DynamicSpace
49+ from modelopt .torch .opt .mode import (
50+ ConvertEntrypoint ,
51+ ConvertReturnType ,
52+ ModeDescriptor ,
53+ RestoreEntrypoint ,
54+ )
4255from modelopt .torch .opt .searcher import BaseSearcher , SearchConfig , SearchStateDict
4356from modelopt .torch .opt .utils import named_hparams
4457from modelopt .torch .utils import print_rank_0
4558
46- from ..fastnas import FastNASModeDescriptor
4759from ..pruning import PruneModeRegistry
4860
4961SUPPORTED_HPARAMS = {
5870 "num_layers" ,
5971}
6072
73+ __all__ = ["MCoreMinitronConfig" , "MCoreMinitronModeDescriptor" , "MCoreMinitronSearcher" ]
74+
6175
6276class MCoreMinitronSearcher (BaseSearcher ):
6377 """Searcher for Minitron pruning algorithm."""
@@ -218,9 +232,48 @@ def run_search(self) -> None:
218232)
219233
220234
def _convert_model_to_dynamic_space(
    model: nn.Module, config: ModeloptBaseConfig | None = None
) -> DynamicSpace:
    """Build a ``DynamicSpace`` around ``model`` and convert it to dynamic modules (in-place).

    The space's ``_should_be_converted`` predicate is overridden so that it accepts only
    instances of the classes used as keys of ``SUPPORTED_MODELS``.

    Args:
        model: The module handed to ``DynamicSpace``.
        config: Optional mode config. When truthy, ``config.model_dump()`` is passed as the
            first argument of ``convert_to_dynamic``; otherwise ``None`` is passed.

    Returns:
        The ``DynamicSpace`` created for ``model``.

    Raises:
        ApplyModeError: If the space reports that it is not configurable after conversion.
    """

    def _is_supported_model(mod):
        # The class tuple is rebuilt on each call, so the check reads SUPPORTED_MODELS lazily.
        return isinstance(mod, tuple(SUPPORTED_MODELS.keys()))

    space = DynamicSpace(model)
    space._should_be_converted = _is_supported_model

    if config:
        rules = config.model_dump()
    else:
        rules = None
    space.convert_to_dynamic(rules, DMRegistry)

    if space.is_configurable():
        return space

    raise ApplyModeError(
        "The model does not contain any configurable hyperparameters! Please check the"
        " documentation for modules and config and how to get a configurable model."
    )
249+
250+
def convert_mcore_minitron(model: nn.Module, config: ModeloptBaseConfig) -> ConvertReturnType:
    """Convert ``model`` to the dynamic search space (in-place) and return it with its metadata.

    This is a simplified version of convert_fastnas_searchspace: there is no automated
    recursive tracing, the top-level model is converted to a DynamicModule directly, and
    submodules are expected to be converted from the top-level model rather than explicitly.

    Args:
        model: The module to convert.
        config: The mode config forwarded to the dynamic-space conversion.

    Returns:
        A ``(model, metadata)`` tuple where ``metadata["subnet_config"]`` holds the result of
        ``get_subnet_config(model)`` taken after the conversion.
    """
    # The conversion must run first so that the subnet config reflects the converted model.
    _convert_model_to_dynamic_space(model, config)
    return model, {"subnet_config": get_subnet_config(model)}
265+
266+
def restore_mcore_minitron(
    model: nn.Module, config: ModeloptBaseConfig, metadata: dict
) -> nn.Module:
    """Restore the model by running the ``mcore_minitron`` conversion on it again.

    Args:
        model: The module to restore.
        config: The mode config forwarded to ``convert_mcore_minitron``.
        metadata: Accepted as part of the signature; not read by this function.

    Returns:
        The model returned by ``convert_mcore_minitron``; the freshly produced metadata is
        discarded.
    """
    restored_model, _ = convert_mcore_minitron(model, config)
    return restored_model
272+
273+
221274@NASModeRegistry .register_mode
222275@PruneModeRegistry .register_mode
223- class MCoreMinitronModeDescriptor (FastNASModeDescriptor ):
276+ class MCoreMinitronModeDescriptor (ModeDescriptor ):
224277 """Class to describe the ``"mcore_minitron"`` mode.
225278
226279 The properties of this mode can be inspected via the source code.
@@ -236,7 +289,27 @@ def config_class(self) -> type[ModeloptBaseConfig]:
236289 """Specifies the config class for the mode."""
237290 return MCoreMinitronConfig
238291
292+ @property
293+ def next_modes (self ) -> set [str ] | None :
294+ """Modes that must immediately follow this mode."""
295+ return {"export" , "kd_loss" , "quantize" , "sparse_magnitude" , "sparse_gpt" }
296+
297+ @property
298+ def export_mode (self ) -> str | None :
299+ """The mode that corresponds to the export mode of this mode."""
300+ return "export"
301+
@property
def search_algorithm(self) -> type[BaseSearcher]:
    """Return the searcher class that implements the search for this mode."""
    searcher_cls = MCoreMinitronSearcher
    return searcher_cls
306+
@property
def convert(self) -> ConvertEntrypoint:
    """Return the entrypoint used by this mode to convert a model to a search space."""
    convert_entrypoint = convert_mcore_minitron
    return convert_entrypoint
311+
@property
def restore(self) -> RestoreEntrypoint:
    """Return the entrypoint used by this mode to restore a model with the modelopt_state."""
    restore_entrypoint = restore_mcore_minitron
    return restore_entrypoint
0 commit comments