Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 1 addition & 43 deletions modelopt/torch/prune/plugins/mcore_minitron.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,38 +59,6 @@
}


def get_supported_models():
"""Get the supported models for Minitron pruning.

NOTE: Keep inside function to avoid circular import issues.
"""
supported_models = set()

try:
from megatron.core.models.gpt import GPTModel

supported_models.add(GPTModel)
except Exception:
pass

try:
from megatron.core.models.mamba import MambaModel

supported_models.add(MambaModel)
except Exception:
pass

try:
from nemo.collections import llm

# NOTE: llm.MambaModel is a subclass of llm.GPTModel
supported_models.add(llm.GPTModel)
except Exception:
pass

return supported_models


class MCoreMinitronSearcher(BaseSearcher):
"""Searcher for Minitron pruning algorithm."""

Expand Down Expand Up @@ -158,17 +126,6 @@ def before_search(self) -> None:
def run_search(self) -> None:
"""Run actual search."""
# Run forward loop to collect activations and sort parameters
model_cfg = None
supported_models = get_supported_models()
for m_type in supported_models:
if isinstance(self.model, m_type):
model_cfg = self.model.config
break
if model_cfg is None:
raise NotImplementedError(
f"Only {supported_models} models are supported! Got: {type(self.model)}"
)

assert self.forward_loop is not None
is_training = self.model.training
self.model.eval()
Expand All @@ -187,6 +144,7 @@ def run_search(self) -> None:
hp.active = export_config[hp_name]

# kv_channels can be None so we need to save original from original hidden_size and num_attention_heads
model_cfg = self.model.config
orig_kv_channels = getattr(model_cfg, "kv_channels")
if orig_kv_channels is None:
orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
Expand Down
Loading