59
59
}
60
60
61
61
62
- def get_supported_models ():
63
- """Get the supported models for Minitron pruning.
64
-
65
- NOTE: Keep inside function to avoid circular import issues.
66
- """
67
- supported_models = set ()
68
-
69
- try :
70
- from megatron .core .models .gpt import GPTModel
71
-
72
- supported_models .add (GPTModel )
73
- except Exception :
74
- pass
75
-
76
- try :
77
- from megatron .core .models .mamba import MambaModel
78
-
79
- supported_models .add (MambaModel )
80
- except Exception :
81
- pass
82
-
83
- try :
84
- from nemo .collections import llm
85
-
86
- # NOTE: llm.MambaModel is a subclass of llm.GPTModel
87
- supported_models .add (llm .GPTModel )
88
- except Exception :
89
- pass
90
-
91
- return supported_models
92
-
93
-
94
62
class MCoreMinitronSearcher (BaseSearcher ):
95
63
"""Searcher for Minitron pruning algorithm."""
96
64
@@ -158,17 +126,6 @@ def before_search(self) -> None:
158
126
def run_search (self ) -> None :
159
127
"""Run actual search."""
160
128
# Run forward loop to collect activations and sort parameters
161
- model_cfg = None
162
- supported_models = get_supported_models ()
163
- for m_type in supported_models :
164
- if isinstance (self .model , m_type ):
165
- model_cfg = self .model .config
166
- break
167
- if model_cfg is None :
168
- raise NotImplementedError (
169
- f"Only { supported_models } models are supported! Got: { type (self .model )} "
170
- )
171
-
172
129
assert self .forward_loop is not None
173
130
is_training = self .model .training
174
131
self .model .eval ()
@@ -187,6 +144,7 @@ def run_search(self) -> None:
187
144
hp .active = export_config [hp_name ]
188
145
189
146
# kv_channels can be None so we need to save original from original hidden_size and num_attention_heads
147
+ model_cfg = self .model .config
190
148
orig_kv_channels = getattr (model_cfg , "kv_channels" )
191
149
if orig_kv_channels is None :
192
150
orig_kv_channels = getattr (model_cfg , "hidden_size" ) // getattr (
0 commit comments