Commit 1cf78b2

Remove pruning supported model list in code (#290)
Signed-off-by: Keval Morabia <[email protected]>
Parent: 8a07376


modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 1 addition & 43 deletions
@@ -59,38 +59,6 @@
 }
 
 
-def get_supported_models():
-    """Get the supported models for Minitron pruning.
-
-    NOTE: Keep inside function to avoid circular import issues.
-    """
-    supported_models = set()
-
-    try:
-        from megatron.core.models.gpt import GPTModel
-
-        supported_models.add(GPTModel)
-    except Exception:
-        pass
-
-    try:
-        from megatron.core.models.mamba import MambaModel
-
-        supported_models.add(MambaModel)
-    except Exception:
-        pass
-
-    try:
-        from nemo.collections import llm
-
-        # NOTE: llm.MambaModel is a subclass of llm.GPTModel
-        supported_models.add(llm.GPTModel)
-    except Exception:
-        pass
-
-    return supported_models
-
-
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""
 
@@ -158,17 +126,6 @@ def before_search(self) -> None:
     def run_search(self) -> None:
         """Run actual search."""
         # Run forward loop to collect activations and sort parameters
-        model_cfg = None
-        supported_models = get_supported_models()
-        for m_type in supported_models:
-            if isinstance(self.model, m_type):
-                model_cfg = self.model.config
-                break
-        if model_cfg is None:
-            raise NotImplementedError(
-                f"Only {supported_models} models are supported! Got: {type(self.model)}"
-            )
-
         assert self.forward_loop is not None
         is_training = self.model.training
         self.model.eval()
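
For reference, the gating removed by the two hunks above can be reproduced as a standalone sketch. The placeholder classes below (_FakeGPTModel, _FakeMambaModel) are hypothetical stand-ins for the real megatron.core / NeMo imports, used only to keep the example self-contained; this is illustrative, not the exact ModelOpt code:

# Illustrative sketch of the removed isinstance-based gating.
class _FakeGPTModel:  # hypothetical stand-in for megatron.core.models.gpt.GPTModel
    pass


class _FakeMambaModel:  # hypothetical stand-in for megatron.core.models.mamba.MambaModel
    pass


def get_supported_models() -> set[type]:
    # The removed helper built this set via try/except-guarded imports so that
    # missing optional packages (megatron.core, nemo) were silently skipped.
    return {_FakeGPTModel, _FakeMambaModel}


def check_model(model: object) -> None:
    # Mirrors the removed run_search() check: reject models of unrecognized type.
    supported_models = get_supported_models()
    if not isinstance(model, tuple(supported_models)):
        raise NotImplementedError(
            f"Only {supported_models} models are supported! Got: {type(model)}"
        )


check_model(_FakeGPTModel())  # accepted
# check_model(object())       # would raise NotImplementedError

With this commit applied, the up-front check is gone and run_search() simply reads self.model.config (see the next hunk), so pruning support is no longer tied to a hard-coded class list.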
@@ -187,6 +144,7 @@ def run_search(self) -> None:
                 hp.active = export_config[hp_name]
 
         # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads
+        model_cfg = self.model.config
         orig_kv_channels = getattr(model_cfg, "kv_channels")
         if orig_kv_channels is None:
             orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
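
The comment in the last hunk explains the fallback that follows the newly added model_cfg = self.model.config line: when kv_channels is None it is recomputed from hidden_size and num_attention_heads. A minimal sketch of that logic, using a hypothetical DummyConfig in place of the real Megatron-Core config object (the hunk's trailing getattr( is truncated in the diff; the num_attention_heads operand is inferred from the comment):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class DummyConfig:
    # Hypothetical stand-in for the config read off self.model.config.
    hidden_size: int = 4096
    num_attention_heads: int = 32
    kv_channels: int | None = None  # may be unset, as the in-diff comment notes


def resolve_kv_channels(model_cfg: DummyConfig) -> int:
    # Mirror of the kv_channels fallback shown above.
    orig_kv_channels = getattr(model_cfg, "kv_channels")
    if orig_kv_channels is None:
        # Fall back to hidden_size // num_attention_heads when kv_channels is None.
        orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
            model_cfg, "num_attention_heads"
        )
    return orig_kv_channels


print(resolve_kv_channels(DummyConfig()))  # -> 128 for the defaults above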
