@@ -7,7 +7,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM
 
 from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward
 
@@ -190,41 +190,41 @@ def get_model_from_config_patched(config, **kwargs):
 # TODO: figure out how this can be incorporated into the export patch system
 AutoModelForCausalLM.from_config = get_model_from_config_patched
 
-_config_from_pretrained_original = AutoConfig.from_pretrained
-_nemotron_h_base_model_tp_plan = {
-    # mamba SSM layer
-    "in_proj": "mamba",
-    "out_proj": "rowwise",
-    # attention layer
-    "q_proj": "colwise",
-    "k_proj": "colwise",
-    "v_proj": "colwise",
-    "o_proj": "rowwise",
-    # NOTE: consider not sharding shared experts and/or
-    # latent projections at all, keeping them replicated.
-    # To do so, comment out the corresponding entries.
-    # moe layer: SHARED experts
-    "up_proj": "colwise",
-    "down_proj": "rowwise",
-    # MoLE: latent projections: simple shard
-    "fc1_latent_proj": "gather",
-    "fc2_latent_proj": "gather",
-}
-
-
-def get_config_from_pretrained_patched(*args, **kwargs):
-    ret = _config_from_pretrained_original(*args, **kwargs)
-    config = ret[0] if isinstance(ret, tuple) else ret
-    # heuristic to check if it's a NemotronH MoE Model
-    model_type = getattr(config, "model_type", None)
-    num_moe_layers = getattr(config, "layers_block_type", []).count("moe")
-    if model_type == "nemotron_h" and num_moe_layers > 0:
-        config.base_model_tp_plan = _nemotron_h_base_model_tp_plan
-    return (config, *ret[1:]) if isinstance(ret, tuple) else config
-
-
-# TODO: figure out how this can be incorporated into the export patch system
-AutoConfig.from_pretrained = get_config_from_pretrained_patched
+# _config_from_pretrained_original = AutoConfig.from_pretrained
+# _nemotron_h_base_model_tp_plan = {
+#     # mamba SSM layer
+#     "in_proj": "mamba",
+#     "out_proj": "rowwise",
+#     # attention layer
+#     "q_proj": "colwise",
+#     "k_proj": "colwise",
+#     "v_proj": "colwise",
+#     "o_proj": "rowwise",
+#     # NOTE: consider not sharding shared experts and/or
+#     # latent projections at all, keeping them replicated.
+#     # To do so, comment out the corresponding entries.
+#     # moe layer: SHARED experts
+#     "up_proj": "colwise",
+#     "down_proj": "rowwise",
+#     # MoLE: latent projections: simple shard
+#     "fc1_latent_proj": "gather",
+#     "fc2_latent_proj": "gather",
+# }
+
+
+# def get_config_from_pretrained_patched(*args, **kwargs):
+#     ret = _config_from_pretrained_original(*args, **kwargs)
+#     config = ret[0] if isinstance(ret, tuple) else ret
+#     # heuristic to check if it's a NemotronH MoE Model
+#     model_type = getattr(config, "model_type", None)
+#     num_moe_layers = getattr(config, "layers_block_type", []).count("moe")
+#     if model_type == "nemotron_h" and num_moe_layers > 0:
+#         config.base_model_tp_plan = _nemotron_h_base_model_tp_plan
+#     return (config, *ret[1:]) if isinstance(ret, tuple) else config
+
+
+# # TODO: figure out how this can be incorporated into the export patch system
+# AutoConfig.from_pretrained = get_config_from_pretrained_patched
 
 # TODO: figure out how this can be incorporated into the export patch system
 # Only patch if the module isn't available
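For context on the block disabled above: `AutoConfig.from_pretrained` returns a bare config by default, but a `(config, unused_kwargs)` tuple when called with `return_unused_kwargs=True`, which is why `get_config_from_pretrained_patched` normalizes both return shapes. Per the TODO, the patch should eventually move into the export patch system; until then, a reversible, scoped version of the same monkey-patch could look like the sketch below. The helper name and the config-mutation hook are hypothetical illustrations, not the actual export patch API:

import contextlib

from transformers import AutoConfig


@contextlib.contextmanager
def scoped_autoconfig_patch(mutate_config):
    """Hypothetical helper: wrap AutoConfig.from_pretrained for the duration of a block.

    `mutate_config` is a callable applied to every loaded config; the original
    method is restored on exit, even if the body raises.
    """
    original = AutoConfig.from_pretrained

    def patched(*args, **kwargs):
        ret = original(*args, **kwargs)
        # With return_unused_kwargs=True, from_pretrained returns (config, unused_kwargs).
        config = ret[0] if isinstance(ret, tuple) else ret
        mutate_config(config)
        return (config, *ret[1:]) if isinstance(ret, tuple) else config

    AutoConfig.from_pretrained = patched
    try:
        yield
    finally:
        AutoConfig.from_pretrained = original

Usage would be e.g. `with scoped_autoconfig_patch(lambda c: setattr(c, "base_model_tp_plan", plan)): ...`, so the override applies only while models are being loaded rather than for the lifetime of the process, which is presumably the scoping the TODO is after.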