Commit 289d2ed

[TRTLLM-8946][feat] Improved heuristics to detect shardable regions (NVIDIA#9200)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent fa9bee7 commit 289d2ed

5 files changed: +466 additions, -381 deletions

5 files changed

+466
-381
lines changed

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

Lines changed: 36 additions & 36 deletions
@@ -7,7 +7,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM
 
 from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward
 
@@ -190,41 +190,41 @@ def get_model_from_config_patched(config, **kwargs):
 # TODO: figure out how this can be incorporated into the export patch system
 AutoModelForCausalLM.from_config = get_model_from_config_patched
 
-_config_from_pretrained_original = AutoConfig.from_pretrained
-_nemotron_h_base_model_tp_plan = {
-    # mamba SSM layer
-    "in_proj": "mamba",
-    "out_proj": "rowwise",
-    # attention layer
-    "q_proj": "colwise",
-    "k_proj": "colwise",
-    "v_proj": "colwise",
-    "o_proj": "rowwise",
-    # NOTE: consider not sharding shared experts and/or
-    # latent projections at all, keeping them replicated.
-    # To do so, comment out the corresponding entries.
-    # moe layer: SHARED experts
-    "up_proj": "colwise",
-    "down_proj": "rowwise",
-    # MoLE: latent projections: simple shard
-    "fc1_latent_proj": "gather",
-    "fc2_latent_proj": "gather",
-}
-
-
-def get_config_from_pretrained_patched(*args, **kwargs):
-    ret = _config_from_pretrained_original(*args, **kwargs)
-    config = ret[0] if isinstance(ret, tuple) else ret
-    # heuristic to check if it's a NemotronH MoE Model
-    model_type = getattr(config, "model_type", None)
-    num_moe_layers = getattr(config, "layers_block_type", []).count("moe")
-    if model_type == "nemotron_h" and num_moe_layers > 0:
-        config.base_model_tp_plan = _nemotron_h_base_model_tp_plan
-    return (config, *ret[1:]) if isinstance(ret, tuple) else config
-
-
-# TODO: figure out how this can be incorporated into the export patch system
-AutoConfig.from_pretrained = get_config_from_pretrained_patched
+# _config_from_pretrained_original = AutoConfig.from_pretrained
+# _nemotron_h_base_model_tp_plan = {
+#     # mamba SSM layer
+#     "in_proj": "mamba",
+#     "out_proj": "rowwise",
+#     # attention layer
+#     "q_proj": "colwise",
+#     "k_proj": "colwise",
+#     "v_proj": "colwise",
+#     "o_proj": "rowwise",
+#     # NOTE: consider not sharding shared experts and/or
+#     # latent projections at all, keeping them replicated.
+#     # To do so, comment out the corresponding entries.
+#     # moe layer: SHARED experts
+#     "up_proj": "colwise",
+#     "down_proj": "rowwise",
+#     # MoLE: latent projections: simple shard
+#     "fc1_latent_proj": "gather",
+#     "fc2_latent_proj": "gather",
+# }
+
+
+# def get_config_from_pretrained_patched(*args, **kwargs):
+#     ret = _config_from_pretrained_original(*args, **kwargs)
+#     config = ret[0] if isinstance(ret, tuple) else ret
+#     # heuristic to check if it's a NemotronH MoE Model
+#     model_type = getattr(config, "model_type", None)
+#     num_moe_layers = getattr(config, "layers_block_type", []).count("moe")
+#     if model_type == "nemotron_h" and num_moe_layers > 0:
+#         config.base_model_tp_plan = _nemotron_h_base_model_tp_plan
+#     return (config, *ret[1:]) if isinstance(ret, tuple) else config
+
+
+# # TODO: figure out how this can be incorporated into the export patch system
+# AutoConfig.from_pretrained = get_config_from_pretrained_patched
 
 # TODO: figure out how this can be incorporated into the export patch system
 # Only patch if the module isn't available

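Note: the TP plan being commented out above keys module-name suffixes (e.g. "q_proj", "out_proj") to sharding strategies such as "colwise", "rowwise", or "gather". Below is a minimal, illustrative sketch, not TensorRT-LLM or transformers code, of how such a suffix-keyed plan can be resolved against fully qualified module names; the helper lookup_sharding and the example module names are assumptions for illustration only.

# Minimal illustrative sketch (assumed helper, not the library implementation):
# resolve a suffix-keyed TP plan against a module's fully qualified name.
from typing import Optional

example_tp_plan = {
    "in_proj": "mamba",    # mamba SSM layer
    "out_proj": "rowwise",
    "q_proj": "colwise",   # attention layer
    "k_proj": "colwise",
    "v_proj": "colwise",
    "o_proj": "rowwise",
}


def lookup_sharding(module_name: str, tp_plan: dict) -> Optional[str]:
    """Return the sharding strategy for a module, or None to keep it replicated."""
    # Match on the last path component, e.g. "model.layers.3.mixer.in_proj" -> "in_proj".
    suffix = module_name.rsplit(".", 1)[-1]
    return tp_plan.get(suffix)


assert lookup_sharding("model.layers.3.mixer.in_proj", example_tp_plan) == "mamba"
assert lookup_sharding("model.layers.7.mlp.router", example_tp_plan) is None  # stays replicated

Keying on the suffix rather than the full path lets a single entry such as "q_proj" cover the corresponding projection in every layer of the model.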