Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/megatron-lm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ Coming soon ...

### ⭐ Pruning

Check out the pruning [getting started section](../pruning/README.md#getting-started) and [guidelines](../pruning/README.md#pruning-guidelines) in the pruning README for configuring pruning parameters.

Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning options are:

- `TARGET_FFN_HIDDEN_SIZE`
Expand All @@ -121,14 +123,20 @@ Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Availab
- `TARGET_NUM_LAYERS`
- `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)

Example for depth pruning Qwen3-8B from 36 to 24 layers:

```sh
PP=1 \
TARGET_NUM_LAYERS=24 \
HF_MODEL_CKPT=<pretrained_model_name_or_path> \
MLM_MODEL_SAVE=/tmp/Qwen3-8B-DPruned \
MLM_MODEL_SAVE=Qwen3-8B-Pruned \
bash megatron-lm/examples/post_training/modelopt/prune.sh qwen/Qwen3-8B
```

> [!TIP]
> If the number of layers in the model is not divisible by the pipeline parallel size (PP), you can configure uneven
> PP by setting `MLM_EXTRA_ARGS="--decoder-first-pipeline-num-layers <X> --decoder-last-pipeline-num-layers <Y>"`

## Learn More About Configuration

For simplicity, we use `shell` scripts and variables as arguments. Each script has at least 1 positional
Expand Down
123 changes: 11 additions & 112 deletions modelopt/torch/nas/autonas.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@
from pydantic import create_model
from torch.nn.modules.batchnorm import _BatchNorm

from modelopt.torch.opt.config import (
ModeloptBaseConfig,
ModeloptField,
get_kwargs_for_create_model_with_rules,
)
from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule
from modelopt.torch.opt.mode import (
ConvertEntrypoint,
Expand All @@ -56,34 +52,35 @@
stats,
torch_detach,
torch_to,
unwrap_model,
)

from .algorithms import ConstraintsFunc, get_constraints_func
from .conversion import NASModeRegistry
from .patch import PatchData, PatchManager, _modelopt_eval_recursion_guard, prep_for_eval
from .registry import DMRegistry
from .search_space import SearchSpace, generate_search_space
from .utils import MODELOPT_BN_CALIB_ITERS, MODELOPT_QUEUE_MAXLEN, get_subnet_config, sample, select
from .search_space import generate_search_space
from .utils import get_subnet_config, sample, select

__all__ = [
"AutoNASConfig",
"AutoNASModeDescriptor",
"AutoNASPatchManager",
"EvolveSearcher",
"ExportConfig",
"ExportModeDescriptor",
"IterativeSearcher",
"RandomSearcher",
"convert_autonas_searchspace",
"convert_searchspace",
"export_searchspace",
"restore_autonas_searchspace",
"restore_export",
"restore_searchspace",
"update_autonas_metadata",
]

# Batch-norm (BN) calibration constants.
# Two different numbers are used here since during training it might take longer to stabilize.
MODELOPT_QUEUE_MAXLEN = 50 # length of the modelopt data queue used for BN calibration
MODELOPT_BN_CALIB_ITERS = (
100 # number of iterations in train mode until BN stats are trusted without calibration
)


def _get_ratio_list():
return (0.5, 0.67, 1.0)
Expand Down Expand Up @@ -132,25 +129,6 @@ def _norm_lin_config():
)


class ExportConfig(ModeloptBaseConfig):
    """Options for the export mode.

    The export mode is applied to export a model once the NAS search is done.
    """

    # Whether the subnet configuration has to match exactly during export.
    strict: bool = ModeloptField(
        title="Strict export",
        description="Enforces that the subnet configuration must exactly match during export.",
        default=True,
    )

    # Whether the subnet is calibrated before it is exported.
    calib: bool = ModeloptField(
        title="Calibration",
        description="Whether to calibrate the subnet before exporting.",
        default=False,
    )


class AutoNASPatchManager(PatchManager):
"""A class to handle the monkey patching of the model for automode."""

Expand Down Expand Up @@ -676,48 +654,6 @@ def update_autonas_metadata(
metadata["subnet_config"] = get_subnet_config(model)


def export_searchspace(model: nn.Module, config: ExportConfig) -> ConvertReturnType:
    """Turn the current subnet configuration of the search space into a regular model.

    Args:
        model: The search-space model to export.
        config: Export options; ``config.calib`` is consulted when the model is patched.

    Returns:
        A tuple of the exported model and a metadata dict holding ``"subnet_config"``.
    """
    # entrypoint sanity check to avoid DP/DDP-wrapped models
    model = unwrap_model(model, raise_error=True)

    # capture the subnet config first so a future convert/restore process can reuse it
    metadata = {"subnet_config": get_subnet_config(model)}

    # a patched model is optionally calibrated and then always unpatched
    if PatchManager.is_patched(model):
        patch_manager = PatchManager.get_manager(model)
        if config.calib:
            patch_manager.call_post_eval()
        patch_manager.unpatch()

    # the export itself happens in-place
    exported_model = SearchSpace(model).export()
    return exported_model, metadata


def restore_export(model: nn.Module, config: ExportConfig, metadata: MetadataDict) -> nn.Module:
    """Re-select the stored subnet configuration and export it to a regular model.

    Args:
        model: The search-space model to restore.
        config: Export options; ``config["strict"]`` is forwarded to ``select``.
        metadata: Stored metadata containing the ``"subnet_config"`` to select.

    Returns:
        The exported model.

    Raises:
        ApplyModeError: If the metadata produced by the export differs from ``metadata``.
    """
    # pick the subnet described by the stored metadata
    select(model, metadata["subnet_config"], strict=config["strict"])

    # re-run the export and keep the metadata it produces for verification
    model, exported_metadata = export_searchspace(model, config)

    # stored and freshly produced metadata must agree
    mismatches = compare_dict(metadata, exported_metadata)
    if not mismatches:
        return model
    raise ApplyModeError(f"Unmatched metadata={mismatches}!")


@NASModeRegistry.register_mode
class AutoNASModeDescriptor(ModeDescriptor):
"""Class to describe the ``"autonas"`` mode.
Expand All @@ -738,12 +674,12 @@ def config_class(self) -> type[ModeloptBaseConfig]:
@property
def next_modes(self) -> set[str] | None:
"""Modes that must immediately follow this mode."""
return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
return {"export_nas", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}

@property
def export_mode(self) -> str | None:
"""The mode that corresponds to the export mode of this mode."""
return "export"
return "export_nas"

@property
def search_algorithm(self) -> type[BaseSearcher]:
Expand All @@ -769,40 +705,3 @@ def update_for_save(self) -> UpdateEntrypoint:
def update_for_new_mode(self) -> UpdateEntrypoint:
"""The mode's entrypoint for updating the models state before new mode."""
return update_autonas_metadata


@NASModeRegistry.register_mode
class ExportModeDescriptor(ModeDescriptor):
    """Descriptor of the ``"export"`` mode.

    All properties of the mode can be looked up directly in the source code below.
    """

    @property
    def name(self) -> str:
        """The string identifier of this mode."""
        return "export"

    @property
    def is_export_mode(self) -> bool:
        """Flag marking this mode as an export mode.

        Returns:
            Always True for this descriptor (the default is False).
        """
        return True

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """The config class associated with this mode."""
        return ExportConfig

    @property
    def convert(self) -> ConvertEntrypoint:
        """The entrypoint used to convert a model in this mode."""
        return export_searchspace

    @property
    def restore(self) -> RestoreEntrypoint:
        """The entrypoint used to restore a model in this mode."""
        return restore_export
129 changes: 122 additions & 7 deletions modelopt/torch/nas/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main APIs+entrypoints for model pruning."""
"""Main APIs+entrypoints for NAS conversion and export."""

from torch import nn

from modelopt.torch.opt.conversion import apply_mode
from modelopt.torch.opt.mode import ModeLike, _ModeRegistryCls
from modelopt.torch.utils import ModelLike, unwrap_model

__all__ = ["convert", "export"]
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.opt.conversion import ApplyModeError, apply_mode
from modelopt.torch.opt.mode import (
ConvertEntrypoint,
ConvertReturnType,
MetadataDict,
ModeDescriptor,
ModeLike,
RestoreEntrypoint,
_ModeRegistryCls,
)
from modelopt.torch.utils import ModelLike, compare_dict, unwrap_model

from .patch import PatchManager
from .search_space import SearchSpace
from .utils import get_subnet_config, select

__all__ = ["ExportConfig", "ExportNASModeDescriptor", "convert", "export"]

NASModeRegistry = _ModeRegistryCls("nas")

Expand Down Expand Up @@ -89,6 +102,108 @@ def convert(
return apply_mode(model, mode, registry=registry)


class ExportConfig(ModeloptBaseConfig):
    """Options for the export mode.

    The export mode is applied to export a model once the NAS search is done.
    """

    # Whether the subnet configuration has to match exactly during export.
    strict: bool = ModeloptField(
        title="Strict export",
        description="Enforces that the subnet configuration must exactly match during export.",
        default=True,
    )

    # Whether the subnet is calibrated before it is exported.
    calib: bool = ModeloptField(
        title="Calibration",
        description="Whether to calibrate the subnet before exporting.",
        default=False,
    )


def export_searchspace(model: nn.Module, config: ExportConfig) -> ConvertReturnType:
    """Turn the current subnet configuration of the search space into a regular model.

    Args:
        model: The search-space model to export.
        config: Export options; ``config.calib`` is consulted when the model is patched.

    Returns:
        A tuple of the exported model and a metadata dict holding ``"subnet_config"``.
    """
    # entrypoint sanity check to avoid DP/DDP-wrapped models
    model = unwrap_model(model, raise_error=True)

    # capture the subnet config first so a future convert/restore process can reuse it
    metadata = {"subnet_config": get_subnet_config(model)}

    # a patched model is optionally calibrated and then always unpatched
    if PatchManager.is_patched(model):
        patch_manager = PatchManager.get_manager(model)
        if config.calib:
            patch_manager.call_post_eval()
        patch_manager.unpatch()

    # the export itself happens in-place
    exported_model = SearchSpace(model).export()
    return exported_model, metadata


def restore_export(model: nn.Module, config: ExportConfig, metadata: MetadataDict) -> nn.Module:
    """Re-select the stored subnet configuration and export it to a regular model.

    Args:
        model: The search-space model to restore.
        config: Export options; ``config["strict"]`` is forwarded to ``select``.
        metadata: Stored metadata; when it has no ``"subnet_config"`` entry the model is
            returned untouched.

    Returns:
        The exported model, or ``model`` unchanged if no subnet config was stored.

    Raises:
        ApplyModeError: If the metadata produced by the export differs from ``metadata``.
    """
    # Megatron save_sharded_modelopt_state does not save subnet_config, so there is
    # nothing to select or export in that case
    if "subnet_config" not in metadata:
        return model

    # pick the subnet described by the stored metadata
    select(model, metadata["subnet_config"], strict=config["strict"])

    # re-run the export and keep the metadata it produces for verification
    model, exported_metadata = export_searchspace(model, config)

    # stored and freshly produced metadata must agree
    mismatches = compare_dict(metadata, exported_metadata)
    if not mismatches:
        return model
    raise ApplyModeError(f"Unmatched metadata={mismatches}!")


@NASModeRegistry.register_mode
class ExportNASModeDescriptor(ModeDescriptor):
    """Descriptor of the ``"export_nas"`` mode.

    All properties of the mode can be looked up directly in the source code below.
    """

    @property
    def name(self) -> str:
        """The string identifier of this mode."""
        return "export_nas"

    @property
    def is_export_mode(self) -> bool:
        """Flag marking this mode as an export mode.

        Returns:
            Always True for this descriptor (the default is False).
        """
        return True

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """The config class associated with this mode."""
        return ExportConfig

    @property
    def convert(self) -> ConvertEntrypoint:
        """The entrypoint used to convert a model in this mode."""
        return export_searchspace

    @property
    def restore(self) -> RestoreEntrypoint:
        """The entrypoint used to restore a model in this mode."""
        return restore_export


def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Module:
"""Export a pruned subnet to a regular model.

Expand Down Expand Up @@ -118,4 +233,4 @@ def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Mod

# apply export mode and return model
config = {"strict": strict, "calib": calib}
return apply_mode(model, [("export", config)], registry=NASModeRegistry)
return apply_mode(model, [("export_nas", config)], registry=NASModeRegistry)
2 changes: 1 addition & 1 deletion modelopt/torch/nas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
__all__ = ["DMRegistry"]


DMRegistry = _DMRegistryCls(prefix="Dynamic") # global instance for the registry
DMRegistry = _DMRegistryCls(prefix="Dynamic") # global instance for the NAS registry
6 changes: 0 additions & 6 deletions modelopt/torch/nas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,6 @@
"replace_forward",
]

# Batch-norm (BN) calibration constants.
# Two different numbers are used here since during training it might take longer to stabilize.
MODELOPT_QUEUE_MAXLEN = 50 # length of the modelopt data queue used for BN calibration
MODELOPT_BN_CALIB_ITERS = (
100 # number of iterations in train mode until BN stats are trusted without calibration
)


@contextmanager
def batch_norm_ignored_flops():
Expand Down
4 changes: 4 additions & 0 deletions modelopt/torch/prune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

# nas is a required - so let's check if it's available
import modelopt.torch.nas
from modelopt.torch.utils import import_plugin

from . import fastnas, gradnas, plugins
from .pruning import *

with import_plugin("mcore_minitron", verbose=False):
from .plugins import mcore_minitron
Loading