|
58 | 58 | "num_layers",
|
59 | 59 | }
|
60 | 60 |
|
61 |
| -SUPPORTED_MODELS = set() |
62 | 61 |
|
63 |
| -try: |
64 |
| - from megatron.core.models.gpt import GPTModel |
def get_supported_models():
    """Return the set of model classes supported by Minitron pruning.

    Each candidate framework is imported on a best-effort basis: if the
    import fails for any reason the class is simply left out of the result.

    NOTE: Keep the imports inside this function to avoid circular import issues.
    """
    models = set()

    try:
        from megatron.core.models.gpt import GPTModel
    except Exception:
        pass
    else:
        models.add(GPTModel)

    try:
        from megatron.core.models.mamba import MambaModel
    except Exception:
        pass
    else:
        models.add(MambaModel)

    try:
        from nemo.collections import llm
    except Exception:
        pass
    else:
        # NOTE: llm.MambaModel is a subclass of llm.GPTModel
        models.add(llm.GPTModel)

    return models
84 | 92 |
|
85 | 93 |
|
86 | 94 | class MCoreMinitronSearcher(BaseSearcher):
|
@@ -151,13 +159,14 @@ def run_search(self) -> None:
|
151 | 159 | """Run actual search."""
|
152 | 160 | # Run forward loop to collect activations and sort parameters
|
153 | 161 | model_cfg = None
|
154 |
| - for m_type in SUPPORTED_MODELS: |
| 162 | + supported_models = get_supported_models() |
| 163 | + for m_type in supported_models: |
155 | 164 | if isinstance(self.model, m_type):
|
156 | 165 | model_cfg = self.model.config
|
157 | 166 | break
|
158 | 167 | if model_cfg is None:
|
159 | 168 | raise NotImplementedError(
|
160 |
| - f"Only {SUPPORTED_MODELS} models are supported! Got: {type(self.model)}" |
| 169 | + f"Only {supported_models} models are supported! Got: {type(self.model)}" |
161 | 170 | )
|
162 | 171 |
|
163 | 172 | assert self.forward_loop is not None
|
|
0 commit comments