diff --git a/examples/megatron-lm/README.md b/examples/megatron-lm/README.md
index d706508fa..755cd016e 100644
--- a/examples/megatron-lm/README.md
+++ b/examples/megatron-lm/README.md
@@ -110,6 +110,8 @@ Coming soon ...
 
 ### ⭐ Pruning
 
+Check out the [getting started section](../pruning/README.md#getting-started) and [pruning guidelines](../pruning/README.md#pruning-guidelines) in the pruning README for configuring the pruning parameters.
+
 Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning options are:
 
 - `TARGET_FFN_HIDDEN_SIZE`
@@ -121,14 +123,20 @@ Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Availab
 - `TARGET_NUM_LAYERS`
 - `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)
 
+Example for depth pruning Qwen3-8B from 36 to 24 layers:
+
 ```sh
 PP=1 \
 TARGET_NUM_LAYERS=24 \
 HF_MODEL_CKPT= \
-MLM_MODEL_SAVE=/tmp/Qwen3-8B-DPruned \
+MLM_MODEL_SAVE=Qwen3-8B-Pruned \
 bash megatron-lm/examples/post_training/modelopt/prune.sh qwen/Qwen3-8B
 ```
 
+> [!TIP]
+> If the number of layers in the model is not divisible by the pipeline parallel size (PP), you can configure
+> uneven PP by setting `MLM_EXTRA_ARGS="--decoder-first-pipeline-num-layers --decoder-last-pipeline-num-layers "`.
+
 ## Learn More About Configuration
 
 For simplicity, we use `shell` scripts and variables as arguments. Each script has at least 1 positional
diff --git a/modelopt/torch/nas/autonas.py b/modelopt/torch/nas/autonas.py
index ed5c3f6be..00393657c 100644
--- a/modelopt/torch/nas/autonas.py
+++ b/modelopt/torch/nas/autonas.py
@@ -30,11 +30,7 @@
 from pydantic import create_model
 from torch.nn.modules.batchnorm import _BatchNorm
 
-from modelopt.torch.opt.config import (
-    ModeloptBaseConfig,
-    ModeloptField,
-    get_kwargs_for_create_model_with_rules,
-)
+from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
 from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule
 from modelopt.torch.opt.mode import (
     ConvertEntrypoint,
@@ -56,34 +52,35 @@
     stats,
     torch_detach,
     torch_to,
-    unwrap_model,
 )
 
 from .algorithms import ConstraintsFunc, get_constraints_func
 from .conversion import NASModeRegistry
 from .patch import PatchData, PatchManager, _modelopt_eval_recursion_guard, prep_for_eval
 from .registry import DMRegistry
-from .search_space import SearchSpace, generate_search_space
-from .utils import MODELOPT_BN_CALIB_ITERS, MODELOPT_QUEUE_MAXLEN, get_subnet_config, sample, select
+from .search_space import generate_search_space
+from .utils import get_subnet_config, sample, select
 
 __all__ = [
     "AutoNASConfig",
     "AutoNASModeDescriptor",
     "AutoNASPatchManager",
     "EvolveSearcher",
-    "ExportConfig",
-    "ExportModeDescriptor",
     "IterativeSearcher",
     "RandomSearcher",
     "convert_autonas_searchspace",
     "convert_searchspace",
-    "export_searchspace",
     "restore_autonas_searchspace",
-    "restore_export",
     "restore_searchspace",
     "update_autonas_metadata",
 ]
 
+# We use two different numbers here since BN stats may take longer to stabilize during training.
+MODELOPT_QUEUE_MAXLEN = 50  # length of the modelopt data queue for BN calibration
+MODELOPT_BN_CALIB_ITERS = (
+    100  # iterations in train mode until we trust BN stats without calibration
+)
+
 
 def _get_ratio_list():
     return (0.5, 0.67, 1.0)
@@ -132,25 +129,6 @@ def _norm_lin_config():
     )
 
 
-class ExportConfig(ModeloptBaseConfig):
-    """Configuration for the export mode.
-
-    This mode is used to export a model after NAS search.
- """ - - strict: bool = ModeloptField( - default=True, - title="Strict export", - description="Enforces that the subnet configuration must exactly match during export.", - ) - - calib: bool = ModeloptField( - default=False, - title="Calibration", - description="Whether to calibrate the subnet before exporting.", - ) - - class AutoNASPatchManager(PatchManager): """A class to handle the monkey patching of the model for automode.""" @@ -676,48 +654,6 @@ def update_autonas_metadata( metadata["subnet_config"] = get_subnet_config(model) -def export_searchspace(model: nn.Module, config: ExportConfig) -> ConvertReturnType: - """Export a subnet configuration of the search space to a regular model.""" - # sanity check to avoid DP/DDP here in the entrypoint - model = unwrap_model(model, raise_error=True) - - # store config from model if we can find it for a future convert/restore process - subnet_config = get_subnet_config(model) - - # Check for patching and calibration - if PatchManager.is_patched(model): - manager = PatchManager.get_manager(model) - if config.calib: - manager.call_post_eval() - manager.unpatch() - - # export model in-place - model = SearchSpace(model).export() - - # construct metadata - metadata = { - "subnet_config": subnet_config, - } - - return model, metadata - - -def restore_export(model: nn.Module, config: ExportConfig, metadata: MetadataDict) -> nn.Module: - """Restore & export the subnet configuration of the search space to a regular model.""" - # select subnet config provided in metadata - select(model, metadata["subnet_config"], strict=config["strict"]) - - # run export - model, metadata_new = export_searchspace(model, config) - - # double check metadata - unmatched_keys = compare_dict(metadata, metadata_new) - if unmatched_keys: - raise ApplyModeError(f"Unmatched metadata={unmatched_keys}!") - - return model - - @NASModeRegistry.register_mode class AutoNASModeDescriptor(ModeDescriptor): """Class to describe the ``"autonas"`` mode. @@ -738,12 +674,12 @@ def config_class(self) -> type[ModeloptBaseConfig]: @property def next_modes(self) -> set[str] | None: """Modes that must immediately follow this mode.""" - return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"} + return {"export_nas", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"} @property def export_mode(self) -> str | None: """The mode that corresponds to the export mode of this mode.""" - return "export" + return "export_nas" @property def search_algorithm(self) -> type[BaseSearcher]: @@ -769,40 +705,3 @@ def update_for_save(self) -> UpdateEntrypoint: def update_for_new_mode(self) -> UpdateEntrypoint: """The mode's entrypoint for updating the models state before new mode.""" return update_autonas_metadata - - -@NASModeRegistry.register_mode -class ExportModeDescriptor(ModeDescriptor): - """Class to describe the ``"export"`` mode. - - The properties of this mode can be inspected via the source code. - """ - - @property - def name(self) -> str: - """Returns the value (str representation) of the mode.""" - return "export" - - @property - def config_class(self) -> type[ModeloptBaseConfig]: - """Specifies the config class for the mode.""" - return ExportConfig - - @property - def is_export_mode(self) -> bool: - """Whether the mode is an export mode. - - Returns: - True if the mode is an export mode, False otherwise. Defaults to False. 
- """ - return True - - @property - def convert(self) -> ConvertEntrypoint: - """The mode's entrypoint for converting a model.""" - return export_searchspace - - @property - def restore(self) -> RestoreEntrypoint: - """The mode's entrypoint for restoring a model.""" - return restore_export diff --git a/modelopt/torch/nas/conversion.py b/modelopt/torch/nas/conversion.py index a93c209fb..f7cb52652 100644 --- a/modelopt/torch/nas/conversion.py +++ b/modelopt/torch/nas/conversion.py @@ -13,15 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Main APIs+entrypoints for model pruning.""" +"""Main APIs+entrypoints for NAS conversion and export.""" from torch import nn -from modelopt.torch.opt.conversion import apply_mode -from modelopt.torch.opt.mode import ModeLike, _ModeRegistryCls -from modelopt.torch.utils import ModelLike, unwrap_model - -__all__ = ["convert", "export"] +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.opt.conversion import ApplyModeError, apply_mode +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ConvertReturnType, + MetadataDict, + ModeDescriptor, + ModeLike, + RestoreEntrypoint, + _ModeRegistryCls, +) +from modelopt.torch.utils import ModelLike, compare_dict, unwrap_model + +from .patch import PatchManager +from .search_space import SearchSpace +from .utils import get_subnet_config, select + +__all__ = ["ExportConfig", "ExportNASModeDescriptor", "convert", "export"] NASModeRegistry = _ModeRegistryCls("nas") @@ -89,6 +102,108 @@ def convert( return apply_mode(model, mode, registry=registry) +class ExportConfig(ModeloptBaseConfig): + """Configuration for the export mode. + + This mode is used to export a model after NAS search. 
+ """ + + strict: bool = ModeloptField( + default=True, + title="Strict export", + description="Enforces that the subnet configuration must exactly match during export.", + ) + + calib: bool = ModeloptField( + default=False, + title="Calibration", + description="Whether to calibrate the subnet before exporting.", + ) + + +def export_searchspace(model: nn.Module, config: ExportConfig) -> ConvertReturnType: + """Export a subnet configuration of the search space to a regular model.""" + # sanity check to avoid DP/DDP here in the entrypoint + model = unwrap_model(model, raise_error=True) + + # store config from model if we can find it for a future convert/restore process + subnet_config = get_subnet_config(model) + + # Check for patching and calibration + if PatchManager.is_patched(model): + manager = PatchManager.get_manager(model) + if config.calib: + manager.call_post_eval() + manager.unpatch() + + # export model in-place + model = SearchSpace(model).export() + + # construct metadata + metadata = { + "subnet_config": subnet_config, + } + + return model, metadata + + +def restore_export(model: nn.Module, config: ExportConfig, metadata: MetadataDict) -> nn.Module: + """Restore & export the subnet configuration of the search space to a regular model.""" + # Megatron save_sharded_modelopt_state does not save subnet_config + if "subnet_config" not in metadata: + return model + + # select subnet config provided in metadata + select(model, metadata["subnet_config"], strict=config["strict"]) + + # run export + model, metadata_new = export_searchspace(model, config) + + # double check metadata + unmatched_keys = compare_dict(metadata, metadata_new) + if unmatched_keys: + raise ApplyModeError(f"Unmatched metadata={unmatched_keys}!") + + return model + + +@NASModeRegistry.register_mode +class ExportNASModeDescriptor(ModeDescriptor): + """Class to describe the ``"export_nas"`` mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "export_nas" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return ExportConfig + + @property + def is_export_mode(self) -> bool: + """Whether the mode is an export mode. + + Returns: + True if the mode is an export mode, False otherwise. Defaults to False. + """ + return True + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return export_searchspace + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_export + + def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Module: """Export a pruned subnet to a regular model. 
@@ -118,4 +233,4 @@ def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Mod
 
     # apply export mode and return model
     config = {"strict": strict, "calib": calib}
-    return apply_mode(model, [("export", config)], registry=NASModeRegistry)
+    return apply_mode(model, [("export_nas", config)], registry=NASModeRegistry)
diff --git a/modelopt/torch/nas/registry.py b/modelopt/torch/nas/registry.py
index 37cee47f1..2fe0e9ce9 100644
--- a/modelopt/torch/nas/registry.py
+++ b/modelopt/torch/nas/registry.py
@@ -20,4 +20,4 @@
 
 __all__ = ["DMRegistry"]
 
-DMRegistry = _DMRegistryCls(prefix="Dynamic")  # global instance for the registry
+DMRegistry = _DMRegistryCls(prefix="Dynamic")  # global instance for the NAS registry
diff --git a/modelopt/torch/nas/utils.py b/modelopt/torch/nas/utils.py
index 51cb6456e..fb5faa368 100644
--- a/modelopt/torch/nas/utils.py
+++ b/modelopt/torch/nas/utils.py
@@ -61,12 +61,6 @@
     "replace_forward",
 ]
 
-# we have two different numbers here since during training it might take longer to stabilize
-MODELOPT_QUEUE_MAXLEN = 50  # indicates length of modelopt data queue for BN calib
-MODELOPT_BN_CALIB_ITERS = (
-    100  # indicates # iters in train mode 'til we trust BN stats without calib
-)
-
 
 @contextmanager
 def batch_norm_ignored_flops():
diff --git a/modelopt/torch/prune/__init__.py b/modelopt/torch/prune/__init__.py
index a81ae3a9f..aac5f7e87 100644
--- a/modelopt/torch/prune/__init__.py
+++ b/modelopt/torch/prune/__init__.py
@@ -21,6 +21,10 @@
 # nas is a required - so let's check if it's available
 import modelopt.torch.nas
+from modelopt.torch.utils import import_plugin
 
 from . import fastnas, gradnas, plugins
 from .pruning import *
+
+with import_plugin("mcore_minitron", verbose=False):
+    from .plugins import mcore_minitron
diff --git a/modelopt/torch/prune/fastnas.py b/modelopt/torch/prune/fastnas.py
index 6e74b8856..4852efdad 100644
--- a/modelopt/torch/prune/fastnas.py
+++ b/modelopt/torch/prune/fastnas.py
@@ -343,12 +343,12 @@ def config_class(self) -> type[ModeloptBaseConfig]:
     @property
     def next_modes(self) -> set[str] | None:
         """Modes that must immediately follow this mode."""
-        return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
+        return {"export_nas", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
 
     @property
     def export_mode(self) -> str | None:
         """The mode that corresponds to the export mode of this mode."""
-        return "export"
+        return "export_nas"
 
     @property
     def search_algorithm(self) -> type[BaseSearcher]:
diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py
index 6d3dfe6eb..5f94c3175 100644
--- a/modelopt/torch/prune/plugins/mcore_minitron.py
+++ b/modelopt/torch/prune/plugins/mcore_minitron.py
@@ -37,6 +37,7 @@
     HAS_MAMBA,
     _DynamicMCoreLanguageModel,
     SUPPORTED_MODELS,
+    drop_mcore_language_model_layers,
 )
 
 # isort: on
@@ -70,7 +71,13 @@
     "num_layers",
 }
 
-__all__ = ["MCoreMinitronConfig", "MCoreMinitronModeDescriptor", "MCoreMinitronSearcher"]
+__all__ = [
+    "SUPPORTED_HPARAMS",
+    "MCoreMinitronConfig",
+    "MCoreMinitronModeDescriptor",
+    "MCoreMinitronSearcher",
+    "drop_mcore_language_model_layers",
+]
 
 
 class MCoreMinitronSearcher(BaseSearcher):
@@ -267,8 +274,8 @@ def convert_mcore_minitron(model: nn.Module, config: ModeloptBaseConfig) -> Conv
 def restore_mcore_minitron(
     model: nn.Module, config: ModeloptBaseConfig, metadata: dict
 ) -> nn.Module:
-    """Restore the model to the original state."""
-    return convert_mcore_minitron(model, config)[0]
+    """Restore the model (no-op, since converting again would force TP=1)."""
+    return model
 
 
 @NASModeRegistry.register_mode
@@ -292,12 +299,12 @@ def config_class(self) -> type[ModeloptBaseConfig]:
     @property
     def next_modes(self) -> set[str] | None:
         """Modes that must immediately follow this mode."""
-        return {"export", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
+        return {"export_nas", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt"}
 
     @property
     def export_mode(self) -> str | None:
         """The mode that corresponds to the export mode of this mode."""
-        return "export"
+        return "export_nas"
 
     @property
     def search_algorithm(self) -> type[BaseSearcher]:
diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
index c53841e4b..b80d5ef7e 100644
--- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
+++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
@@ -24,6 +24,7 @@
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
     get_mcore_gpt_model,
+    run_mcore_inference,
     run_mcore_inference_with_dummy_input,
 )
 
@@ -88,6 +89,7 @@ def _get_model(initialize_megatron=True):
         return model
 
     model = _get_model()
+    sd = model.state_dict()
 
     def forward_loop(m):
         for _ in range(5):
@@ -154,19 +156,24 @@ def forward_loop(m):
     assert model.config.num_layers == pruned_num_layers
 
     # Assert forward pass works on the pruned model
-    run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)
+    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
+    output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size)
 
     # Assert re-pruning from scores_path works without running the forward loop again
     if ckpt_path:
-        model = _get_model(initialize_megatron=False)
+        model_rerun = _get_model(initialize_megatron=False)
+        model_rerun.load_state_dict(sd)
         mtp.prune(
-            model,
+            model_rerun,
             mode="mcore_minitron",
             constraints={"export_config": export_config},
             dummy_input=None,  # Not used
             config={"scores_path": ckpt_path},
         )
 
+        output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
+        assert torch.allclose(output, output_rerun, atol=1e-5)
+
 
 @pytest.mark.parametrize(
     (
diff --git a/tests/unit/torch/nas/plugins/test_hf_nas_save_restore.py b/tests/unit/torch/nas/plugins/test_hf_nas_save_restore.py
index d1270f36e..e1026dc4c 100644
--- a/tests/unit/torch/nas/plugins/test_hf_nas_save_restore.py
+++ b/tests/unit/torch/nas/plugins/test_hf_nas_save_restore.py
@@ -31,7 +31,7 @@ def test_pruned_transformers_save_restore(tmp_path):
     model_ref = BertForQuestionAnswering.from_pretrained(tiny_bert_dir)
 
     # Export a random subnet (proxy for search / prune)
-    model_ref = apply_mode_with_sampling(model_ref, ["fastnas", "export"])
+    model_ref = apply_mode_with_sampling(model_ref, ["fastnas", "export_nas"])
     model_ref.save_pretrained(tiny_bert_dir / "modelopt_model")
 
     assert os.path.exists(tiny_bert_dir / "modelopt_model/modelopt_state.pth")
diff --git a/tests/unit/torch/nas/test_nas.py b/tests/unit/torch/nas/test_nas.py
index a389b1622..2de8f8e3b 100644
--- a/tests/unit/torch/nas/test_nas.py
+++ b/tests/unit/torch/nas/test_nas.py
@@ -378,7 +378,7 @@ def _decorated_compare():
             "",
             False,
             torch.randn(1, 16, 8, 8),
-            ["autonas", "export"],
+            ["autonas", "export_nas"],
         ),
         (
             InvertedResidual,
@@ -386,7 +386,7 @@ def _decorated_compare():
             "conv.1",
             True,
             torch.randn(1, 16, 8, 8),
-            ["autonas", "export"],
+ ["autonas", "export_nas"], ), ( InvertedResidual, @@ -410,9 +410,16 @@ def _decorated_compare(): "", True, torch.randn(1, 16, 8, 8), - ["autonas", "export"], + ["autonas", "export_nas"], + ), + ( + TinyMobileNetFeatures, + (), + "", + False, + torch.randn(1, 3, 64, 64), + ["autonas", "export_nas"], ), - (TinyMobileNetFeatures, (), "", False, torch.randn(1, 3, 64, 64), ["autonas", "export"]), (TinyMobileNetFeatures, (), "", False, torch.randn(1, 3, 64, 64), ["autonas"]), (TinyMobileNetFeatures, (), "", False, torch.randn(1, 3, 64, 64), []), ], @@ -424,9 +431,9 @@ def test_save_restore_whole( # setup model model = cls(*args) - # check for "export" - if "export" in mode: - mode.remove("export") + # check for "export_nas" + if "export_nas" in mode: + mode.remove("export_nas") use_export = True else: use_export = False diff --git a/tests/unit/torch/opt/test_chaining.py b/tests/unit/torch/opt/test_chaining.py index 1682292a8..39cad5a4c 100644 --- a/tests/unit/torch/opt/test_chaining.py +++ b/tests/unit/torch/opt/test_chaining.py @@ -47,20 +47,28 @@ def get_kd_mode(): "mode", [ ["autonas"], - ["autonas", "export"], - ["autonas", "export", "fastnas"], - ["autonas", "export", "fastnas", "export"], - ["autonas", "export", "fastnas", "export", get_kd_mode()], - ["autonas", "export", "fastnas", "export", get_kd_mode(), "export_student"], - ["autonas", "export", "fastnas", "export", "quantize", get_kd_mode(), "export_student"], - [get_kd_mode(), "export_student", "fastnas", "export", get_kd_mode(), "export_student"], + ["autonas", "export_nas"], + ["autonas", "export_nas", "fastnas"], + ["autonas", "export_nas", "fastnas", "export_nas"], + ["autonas", "export_nas", "fastnas", "export_nas", get_kd_mode()], + ["autonas", "export_nas", "fastnas", "export_nas", get_kd_mode(), "export_student"], + [ + "autonas", + "export_nas", + "fastnas", + "export_nas", + "quantize", + get_kd_mode(), + "export_student", + ], + [get_kd_mode(), "export_student", "fastnas", "export_nas", get_kd_mode(), "export_student"], ["quantize"], - ["fastnas", get_kd_mode(), "export_student", "export"], + ["fastnas", get_kd_mode(), "export_student", "export_nas"], ["sparse_magnitude", get_kd_mode(), "export_student", "export_sparse"], ["sparse_magnitude", "quantize", get_kd_mode(), "export_student"], - ["fastnas", "export", "sparse_magnitude", "quantize", get_kd_mode(), "export_student"], + ["fastnas", "export_nas", "sparse_magnitude", "quantize", get_kd_mode(), "export_student"], ["fastnas", "quantize", get_kd_mode(), "export_student"], - ["fastnas", "sparse_magnitude", "export_sparse", "export"], + ["fastnas", "sparse_magnitude", "export_sparse", "export_nas"], ], ) def test_chained_save_restore(mode): @@ -95,20 +103,20 @@ def test_chained_save_restore(mode): ("mode", "error_msg"), [ ( - ["export"], - [r"Cannot add export according to the current export stack: deque\(\[.*\]\)."], + ["export_nas"], + [r"Cannot add export_nas according to the current export stack: deque\(\[.*\]\)."], ), ( ["autonas", "fastnas"], [r"Cannot add fastnas after autonas! 
         ),
         (
-            ["fastnas", "export", "export_student"],
+            ["fastnas", "export_nas", "export_student"],
             [r"Cannot add export_student according to the current export stack: deque\(\[.*\]\)."],
         ),
         (
-            ["quantize", "export"],
-            [r"Cannot add export according to the current export stack: deque\(\[.*\]\)."],
+            ["quantize", "export_nas"],
+            [r"Cannot add export_nas according to the current export stack: deque\(\[.*\]\)."],
         ),
         (
             ["quantize", "fastnas"],
             [
             ],
         ),
         (
-            ["fastnas", get_kd_mode(), "export", "export_student"],
-            [r"Cannot add export according to the current export stack: deque\(\[.*\]\)."],
+            ["fastnas", get_kd_mode(), "export_nas", "export_student"],
+            [r"Cannot add export_nas according to the current export stack: deque\(\[.*\]\)."],
         ),
     ],
 )
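
For context, here is a minimal usage sketch of the renamed export mode. This is an illustrative example rather than part of the patch: it assumes the public `modelopt.torch.nas` convert/export API shown above, and the toy model is hypothetical.

```python
import torch.nn as nn

import modelopt.torch.nas as mtn

# Hypothetical toy model; any architecture supported by the NAS search space works.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))

# Convert the model into a searchable space (e.g., via the "fastnas" mode).
model = mtn.convert(model, mode="fastnas")

# ... run search / select a subnet here ...

# mtn.export() keeps its signature, but internally it now applies the
# "export_nas" mode (renamed from "export"), so modelopt states saved after
# this change restore against the new mode name.
model = mtn.export(model, strict=True, calib=False)
```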