Skip to content

Commit 90b61ed

Browse files
Move pruning nemo import in function to avoid circular import
Signed-off-by: Keval Morabia <[email protected]>
1 parent c359cb7 commit 90b61ed

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,37 @@
5858
"num_layers",
5959
}
6060

61-
SUPPORTED_MODELS = set()
6261

63-
try:
64-
from megatron.core.models.gpt import GPTModel
62+
def get_supported_models():
63+
"""Get the supported models for Minitron pruning.
6564
66-
SUPPORTED_MODELS.add(GPTModel)
67-
except Exception:
68-
pass
65+
NOTE: Keep inside function to avoid circular import issues.
66+
"""
67+
supported_models = set()
68+
69+
try:
70+
from megatron.core.models.gpt import GPTModel
71+
72+
supported_models.add(GPTModel)
73+
except Exception:
74+
pass
75+
76+
try:
77+
from megatron.core.models.mamba import MambaModel
6978

70-
try:
71-
from megatron.core.models.mamba import MambaModel
79+
supported_models.add(MambaModel)
80+
except Exception:
81+
pass
7282

73-
SUPPORTED_MODELS.add(MambaModel)
74-
except Exception:
75-
pass
83+
try:
84+
from nemo.collections import llm
7685

77-
try:
78-
from nemo.collections import llm
86+
# NOTE: llm.MambaModel is a subclass of llm.GPTModel
87+
supported_models.add(llm.GPTModel)
88+
except Exception:
89+
pass
7990

80-
# NOTE: llm.MambaModel is a subclass of llm.GPTModel
81-
SUPPORTED_MODELS.add(llm.GPTModel)
82-
except Exception:
83-
pass
91+
return supported_models
8492

8593

8694
class MCoreMinitronSearcher(BaseSearcher):
@@ -151,13 +159,14 @@ def run_search(self) -> None:
151159
"""Run actual search."""
152160
# Run forward loop to collect activations and sort parameters
153161
model_cfg = None
154-
for m_type in SUPPORTED_MODELS:
162+
supported_models = get_supported_models()
163+
for m_type in supported_models:
155164
if isinstance(self.model, m_type):
156165
model_cfg = self.model.config
157166
break
158167
if model_cfg is None:
159168
raise NotImplementedError(
160-
f"Only {SUPPORTED_MODELS} models are supported! Got: {type(self.model)}"
169+
f"Only {supported_models} models are supported! Got: {type(self.model)}"
161170
)
162171

163172
assert self.forward_loop is not None

0 commit comments

Comments
 (0)