5959}
6060
6161
def get_supported_models():
    """Return the set of model classes supported by Minitron pruning.

    Each candidate framework (Megatron-Core, NeMo) is an optional dependency,
    so every lookup is wrapped in a best-effort guard and simply skipped when
    the package is unavailable.

    NOTE: Keep inside function to avoid circular import issues.
    """
    candidates = set()

    def _collect(loader):
        # Best-effort: optional dependencies may be missing or broken; in
        # either case the corresponding model class is simply not supported.
        try:
            candidates.add(loader())
        except Exception:
            pass

    def _mcore_gpt():
        from megatron.core.models.gpt import GPTModel

        return GPTModel

    def _mcore_mamba():
        from megatron.core.models.mamba import MambaModel

        return MambaModel

    def _nemo_gpt():
        from nemo.collections import llm

        # NOTE: llm.MambaModel is a subclass of llm.GPTModel
        return llm.GPTModel

    for loader in (_mcore_gpt, _mcore_mamba, _nemo_gpt):
        _collect(loader)

    return candidates
92-
93-
9462class MCoreMinitronSearcher (BaseSearcher ):
9563 """Searcher for Minitron pruning algorithm."""
9664
@@ -158,17 +126,6 @@ def before_search(self) -> None:
158126 def run_search (self ) -> None :
159127 """Run actual search."""
160128 # Run forward loop to collect activations and sort parameters
161- model_cfg = None
162- supported_models = get_supported_models ()
163- for m_type in supported_models :
164- if isinstance (self .model , m_type ):
165- model_cfg = self .model .config
166- break
167- if model_cfg is None :
168- raise NotImplementedError (
169- f"Only { supported_models } models are supported! Got: { type (self .model )} "
170- )
171-
172129 assert self .forward_loop is not None
173130 is_training = self .model .training
174131 self .model .eval ()
@@ -187,6 +144,7 @@ def run_search(self) -> None:
187144 hp .active = export_config [hp_name ]
188145
189146 # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads
147+ model_cfg = self .model .config
190148 orig_kv_channels = getattr (model_cfg , "kv_channels" )
191149 if orig_kv_channels is None :
192150 orig_kv_channels = getattr (model_cfg , "hidden_size" ) // getattr (
0 commit comments