14 changes: 2 additions & 12 deletions modelopt/torch/nas/plugins/megatron.py
@@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Plugin to add NAS/Pruning support for megatron-core GPT model."""
"""Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""

import types
from collections.abc import Callable, Sequence
from typing import Any
from warnings import warn

import torch
import torch.nn as nn
@@ -98,7 +97,7 @@
except ImportError:
HAS_MAMBA = False

__all__ = ["drop_mcore_gpt_layers", "drop_mcore_language_model_layers"]
__all__ = ["drop_mcore_language_model_layers"]


class _DynamicParallelLinear(DynamicModule):
@@ -1457,15 +1456,6 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i
model.config.num_layers = new_num_layers


def drop_mcore_gpt_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
"""[DEPRECATED] Remove given layers (1-indexed) of the model (works with TP and/or PP)."""
warn(
"`drop_mcore_gpt_layers` is deprecated in favor of `drop_mcore_language_model_layers`.",
DeprecationWarning,
)
drop_mcore_language_model_layers(model, layers_to_drop=layers_to_drop)


class MegatronConstraintsFunc(ConstraintsFunc):
"""A Functor class to check if sub-net satisfied all provided constraints.

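Downstream callers of the removed shim only need a rename, since the keyword argument is unchanged. A minimal migration sketch (the model object and the 1-indexed layer list are placeholders, not part of this PR):

# Before (deprecated wrapper, now removed):
#   drop_mcore_gpt_layers(model, layers_to_drop=[1, 2, 3])
# After -- same behavior, and also covers Mamba models:
from modelopt.torch.nas.plugins.megatron import drop_mcore_language_model_layers

drop_mcore_language_model_layers(model, layers_to_drop=[1, 2, 3])
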
85 changes: 79 additions & 6 deletions modelopt/torch/prune/plugins/mcore_minitron.py
@@ -27,23 +27,35 @@
import copy

import torch
import torch.nn as nn
from pydantic import create_model

# isort: off
# import nas plugin to check if it is enabled else raises an Exception
# import nas plugin to check if it is enabled else raises an Exception and disables the plugin
from modelopt.torch.nas.plugins.megatron import * # noqa: F403
from modelopt.torch.nas.plugins.megatron import HAS_MAMBA, _DynamicMCoreLanguageModel
from modelopt.torch.nas.plugins.megatron import (
HAS_MAMBA,
_DynamicMCoreLanguageModel,
SUPPORTED_MODELS,
)
# isort: on

from modelopt.torch.nas.conversion import NASModeRegistry
from modelopt.torch.nas.registry import DMRegistry
from modelopt.torch.nas.utils import sort_parameters
from modelopt.torch.nas.utils import get_subnet_config, sort_parameters
from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
from modelopt.torch.opt.conversion import ApplyModeError
from modelopt.torch.opt.dynamic import DynamicSpace
from modelopt.torch.opt.mode import (
ConvertEntrypoint,
ConvertReturnType,
ModeDescriptor,
RestoreEntrypoint,
)
from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict
from modelopt.torch.opt.utils import named_hparams
from modelopt.torch.utils import print_rank_0

from ..fastnas import FastNASModeDescriptor
from ..pruning import PruneModeRegistry

SUPPORTED_HPARAMS = {
@@ -58,6 +70,8 @@
"num_layers",
}

__all__ = ["MCoreMinitronConfig", "MCoreMinitronModeDescriptor", "MCoreMinitronSearcher"]


class MCoreMinitronSearcher(BaseSearcher):
"""Searcher for Minitron pruning algorithm."""
@@ -218,9 +232,48 @@ def run_search(self) -> None:
)


def _convert_model_to_dynamic_space(
model: nn.Module, config: ModeloptBaseConfig | None = None
) -> DynamicSpace:
"""Create a dynamic space for the model (in-place)."""
dynamic_space = DynamicSpace(model)
dynamic_space._should_be_converted = lambda mod: isinstance(mod, tuple(SUPPORTED_MODELS.keys()))
dynamic_space.convert_to_dynamic(config.model_dump() if config else None, DMRegistry)
if not dynamic_space.is_configurable():
raise ApplyModeError(
"The model does not contain any configurable hyperparameters! Please check the"
" documentation for modules and config and how to get a configurable model."
)

return dynamic_space


def convert_mcore_minitron(model: nn.Module, config: ModeloptBaseConfig) -> ConvertReturnType:
"""Convert the model to the dynamic search space (in-place) and return the converted model and metadata.

This is a simplified version of convert_fastnas_searchspace that removes the automated recursive tracing
and instead directly converts the top-level model to a DynamicModule. Submodules should not need to be explicitly
converted as that happens from the top-level model.
"""
_convert_model_to_dynamic_space(model, config)

# store current config in metadata
metadata = {"subnet_config": get_subnet_config(model)}

# return converted model as well as metadata
return model, metadata


def restore_mcore_minitron(
model: nn.Module, config: ModeloptBaseConfig, metadata: dict
) -> nn.Module:
"""Restore the model to the original state."""
return convert_mcore_minitron(model, config)[0]


@NASModeRegistry.register_mode
@PruneModeRegistry.register_mode
class MCoreMinitronModeDescriptor(FastNASModeDescriptor):
class MCoreMinitronModeDescriptor(ModeDescriptor):
"""Class to describe the ``"mcore_minitron"`` mode.

The properties of this mode can be inspected via the source code.
@@ -236,7 +289,27 @@ def config_class(self) -> type[ModeloptBaseConfig]:
"""Specifies the config class for the mode."""
return MCoreMinitronConfig

@property
def next_modes(self) -> set[str] | None:
"""Modes that must immediately follow this mode."""
return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}

@property
def export_mode(self) -> str | None:
"""The mode that corresponds to the export mode of this mode."""
return "export"

@property
def search_algorithm(self) -> type[BaseSearcher]:
"""Specifies the search algorithm to use for this mode (if any)."""
"""Specifies the search algorithm to use for this mode."""
return MCoreMinitronSearcher

@property
def convert(self) -> ConvertEntrypoint:
"""The mode's entrypoint for converting a model to a search space."""
return convert_mcore_minitron

@property
def restore(self) -> RestoreEntrypoint:
"""The mode's entrypoint for restoring a model with the modelopt_state."""
return restore_mcore_minitron
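
For reference, a minimal sketch of how the two new entrypoints are meant to be called. Here build_model() is a placeholder for constructing a megatron-core GPT/Mamba model, and a default-constructed MCoreMinitronConfig is assumed to be sufficient for illustration (both are assumptions, not part of this PR):

from modelopt.torch.prune.plugins.mcore_minitron import (
    MCoreMinitronConfig,
    convert_mcore_minitron,
    restore_mcore_minitron,
)

config = MCoreMinitronConfig()  # assumption: defaults are enough for illustration

# Applying the mode: converts the model to the dynamic search space in place
# and records the current subnet config as metadata.
model, metadata = convert_mcore_minitron(build_model(), config)

# Restoring from a saved modelopt_state: re-runs the same conversion on a fresh model.
restored = restore_mcore_minitron(build_model(), config, metadata)
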
17 changes: 2 additions & 15 deletions modelopt/torch/trace/plugins/__init__.py
@@ -15,20 +15,7 @@

"""Handles tracing plugins for third-party modules."""

import warnings as _warnings
from modelopt.torch.utils import import_plugin

try:
from .megatron import *

except ImportError:
pass
except Exception as e:
_warnings.warn(f"Failed to import megatron plugin due to: {e!r}")

try:
with import_plugin("transformers"):
from .transformers import *

except ImportError:
pass
except Exception as e:
_warnings.warn(f"Failed to import transformers plugin due to: {e!r}")
36 changes: 0 additions & 36 deletions modelopt/torch/trace/plugins/megatron.py

This file was deleted.

11 changes: 6 additions & 5 deletions tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
@@ -48,8 +48,9 @@
_DynamicVocabParallelEmbedding,
expand_head_indices,
)
from modelopt.torch.nas.search_space import generate_search_space
from modelopt.torch.nas.registry import DMRegistry
from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
from modelopt.torch.utils import flatten_tree
from modelopt.torch.utils.random import centroid

@@ -178,7 +179,7 @@ def _test_gpt_parameter_sorting(activation_func, rank, size):
m.weight.data = torch.randn_like(m.weight)

model.eval()
search_space = generate_search_space(model)
dynamic_space = _convert_model_to_dynamic_space(model)

# Compute activations for sorting
for _ in range(5):
@@ -188,18 +189,18 @@ def _test_gpt_parameter_sorting(activation_func, rank, size):
prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
y1 = run_mcore_inference(model, prompt_tokens)

search_space.sort_parameters()
mtn.utils.sort_parameters(model)

# check if all ffn_hidden_size, num_heads_per_group, num_query_groups, hidden_size have been sorted
sortable_per_pp = [
n for n, hp in search_space.named_hparams(configurable=True) if hp.importance is not None
n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
]
# 3 hps per layer + 1 for hidden_size (num_layers is not sorted!)
assert len(sortable_per_pp) == 3 * num_layers // size + 1

# Export since sorting force reassigns SelfAttention weights which we dont want to re-sort!
# TODO: ideally we shouldn't need this
search_space.export()
dynamic_space.export(DMRegistry)

# sanity check if the model functionality is preserved after sorting
y2 = run_mcore_inference(model, prompt_tokens)
@@ -43,9 +43,9 @@
_DynamicRowParallelLinear,
_DynamicVocabParallelEmbedding,
)
from modelopt.torch.nas.search_space import generate_search_space
from modelopt.torch.nas.traced_hp import TracedHp
from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
from modelopt.torch.utils import flatten_tree
from modelopt.torch.utils.random import centroid

@@ -163,7 +163,7 @@ def _test_mamba_parameter_sorting(rank, size):
m.weight.data = torch.randn_like(m.weight)

model.eval()
search_space = generate_search_space(model)
dynamic_space = _convert_model_to_dynamic_space(model)

# Compute activations for sorting
for _ in range(5):
@@ -173,11 +173,11 @@ def _test_mamba_parameter_sorting(rank, size):
prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
y1 = run_mcore_inference(model, prompt_tokens)

search_space.sort_parameters()
dynamic_space.sort_parameters()

# check if all mamba_num_heads, mamba_head_dim, hidden_size have been sorted
sortable_per_pp = [
n for n, hp in search_space.named_hparams(configurable=True) if hp.importance is not None
n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
]
# 2 mamba hps per layer + 1 for hidden_size (num_layers is not sorted!)
assert len(sortable_per_pp) == 2 * num_layers // size + 1
14 changes: 0 additions & 14 deletions tests/unit/torch/trace/test_symbol.py
@@ -139,20 +139,6 @@ def test_sym_map_registry():
except ImportError:
pass

try:
from megatron.core.models.gpt import GPTModel

mods_in_registry.add(GPTModel)
except ImportError:
pass

try:
from megatron.core.models.mamba import MambaModel

mods_in_registry.add(MambaModel)
except ImportError:
pass

not_a_leaf = {nn.Sequential}
dependent_registry = set()

Expand Down