|
58 | 58 | "num_layers", |
59 | 59 | } |
60 | 60 |
|
def get_supported_models():
    """Get the supported models for Minitron pruning.

    NOTE: Kept inside a function (with lazy, function-scope imports) to avoid
    circular import issues, and to tolerate environments where the optional
    ``megatron`` / ``nemo`` dependencies are not installed.

    Returns:
        set: Model classes that Minitron pruning supports. Empty when none of
        the optional dependencies are importable.
    """
    import importlib

    # (module_path, attribute_name) pairs for each optionally-available model
    # class. Kept as data so adding a new supported model is a one-line change
    # instead of another copy-pasted try/except block.
    # NOTE: nemo's llm.MambaModel is a subclass of llm.GPTModel, so registering
    # llm.GPTModel covers both.
    candidates = [
        ("megatron.core.models.gpt", "GPTModel"),
        ("megatron.core.models.mamba", "MambaModel"),
        ("nemo.collections.llm", "GPTModel"),
    ]

    supported_models = set()
    for module_path, attr_name in candidates:
        try:
            module = importlib.import_module(module_path)
            supported_models.add(getattr(module, attr_name))
        except Exception:
            # Best-effort: each dependency is optional. Broad Exception is
            # deliberate — these imports can fail in many ways (missing
            # package, CUDA/env issues), and an unavailable model is simply
            # not supported.
            pass

    return supported_models
84 | 92 |
|
85 | 93 |
|
86 | 94 | class MCoreMinitronSearcher(BaseSearcher): |
@@ -151,13 +159,14 @@ def run_search(self) -> None: |
151 | 159 | """Run actual search.""" |
152 | 160 | # Run forward loop to collect activations and sort parameters |
153 | 161 | model_cfg = None |
154 | | - for m_type in SUPPORTED_MODELS: |
| 162 | + supported_models = get_supported_models() |
| 163 | + for m_type in supported_models: |
155 | 164 | if isinstance(self.model, m_type): |
156 | 165 | model_cfg = self.model.config |
157 | 166 | break |
158 | 167 | if model_cfg is None: |
159 | 168 | raise NotImplementedError( |
160 | | - f"Only {SUPPORTED_MODELS} models are supported! Got: {type(self.model)}" |
| 169 | + f"Only {supported_models} models are supported! Got: {type(self.model)}" |
161 | 170 | ) |
162 | 171 |
|
163 | 172 | assert self.forward_loop is not None |
|
0 commit comments