diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py
index abba65bc9..2fd4b439a 100644
--- a/modelopt/torch/prune/plugins/mcore_minitron.py
+++ b/modelopt/torch/prune/plugins/mcore_minitron.py
@@ -59,38 +59,6 @@
 }
 
 
-def get_supported_models():
-    """Get the supported models for Minitron pruning.
-
-    NOTE: Keep inside function to avoid circular import issues.
-    """
-    supported_models = set()
-
-    try:
-        from megatron.core.models.gpt import GPTModel
-
-        supported_models.add(GPTModel)
-    except Exception:
-        pass
-
-    try:
-        from megatron.core.models.mamba import MambaModel
-
-        supported_models.add(MambaModel)
-    except Exception:
-        pass
-
-    try:
-        from nemo.collections import llm
-
-        # NOTE: llm.MambaModel is a subclass of llm.GPTModel
-        supported_models.add(llm.GPTModel)
-    except Exception:
-        pass
-
-    return supported_models
-
-
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""
 
@@ -158,17 +126,6 @@ def before_search(self) -> None:
     def run_search(self) -> None:
         """Run actual search."""
         # Run forward loop to collect activations and sort parameters
-        model_cfg = None
-        supported_models = get_supported_models()
-        for m_type in supported_models:
-            if isinstance(self.model, m_type):
-                model_cfg = self.model.config
-                break
-        if model_cfg is None:
-            raise NotImplementedError(
-                f"Only {supported_models} models are supported! Got: {type(self.model)}"
-            )
-
         assert self.forward_loop is not None
         is_training = self.model.training
         self.model.eval()
@@ -187,6 +144,7 @@ def run_search(self) -> None:
             hp.active = export_config[hp_name]
 
         # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads
+        model_cfg = self.model.config
         orig_kv_channels = getattr(model_cfg, "kv_channels")
         if orig_kv_channels is None:
             orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
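
Net effect of this diff: `run_search` no longer gates on an import-guarded allow-list of model classes (`GPTModel`, `MambaModel`, NeMo `llm.GPTModel`) and instead duck-types on the model's `config` attribute, reading it only at the point where `kv_channels` is needed. The sketch below illustrates that behavior with hypothetical stand-ins; `DummyConfig`, `DummyModel`, and `sketch_kv_channels` are not part of ModelOpt and exist only to show the duck-typed access and the `kv_channels` fallback that the diff keeps.

```python
# Minimal sketch (assumed names, not ModelOpt code) of the post-diff behavior:
# any model exposing a `config` with hidden_size / num_attention_heads /
# kv_channels now works, with no isinstance allow-list.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    hidden_size: int = 4096
    num_attention_heads: int = 32
    kv_channels: int | None = None  # may be None, as the diff handles


class DummyModel:
    def __init__(self) -> None:
        self.config = DummyConfig()


def sketch_kv_channels(model: DummyModel) -> int:
    # Before: isinstance checks against the supported-model set decided
    # whether model.config could be read. After: read it directly.
    model_cfg = model.config

    # kv_channels can be None, so fall back to
    # hidden_size // num_attention_heads, mirroring the kept logic.
    orig_kv_channels = getattr(model_cfg, "kv_channels")
    if orig_kv_channels is None:
        orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
            model_cfg, "num_attention_heads"
        )
    return orig_kv_channels


print(sketch_kv_channels(DummyModel()))  # 4096 // 32 == 128
```

A side effect worth noting: an unsupported model no longer fails fast with `NotImplementedError` in `run_search`; presumably validation now happens elsewhere or relies on the model simply exposing the expected `config` attributes.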