
Commit f3116e5

load hook fix
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent a8df449 commit f3116e5

2 files changed: +25 -13 lines

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py

Lines changed: 25 additions & 10 deletions
@@ -128,8 +128,6 @@ def __init__(self, config, layer_idx: int):
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.use_bias = config.use_bias
 
-        self.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
-
     def torch_forward(self, input_states):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
@@ -191,14 +189,6 @@ def torch_forward(self, input_states):
     def forward(self, hidden_states):
         return self.torch_forward(hidden_states)
 
-    @staticmethod
-    def _load_state_dict_pre_hook(module, state_dict, prefix, local_metadata, strict,
-                                  missing_keys, unexpected_keys, error_msgs) -> None:
-        A_log_key = prefix + "A_log"
-        A_minus_key = prefix + "A_minus"
-        if A_log_key in state_dict:
-            state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float())
-
 
 class NemotronHRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -592,6 +582,13 @@ def __init__(self, config):
         self.backbone = NemotronHModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Recursively iterate over all modules in self.backbone and list those with A_minus or A_log in their name
+        self.backbone_modules_with_A = []
+        for module_name, module in self.backbone.named_modules():
+            for param_name, _ in module.named_parameters(recurse=False):
+                if param_name in ("A_minus", "A_log"):
+                    self.register_load_state_dict_pre_hook(self._a_log_pre_hook)
+                    self.backbone_modules_with_A.append((module_name, param_name))
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -622,5 +619,23 @@ def forward(
 
         return NemotronHCausalLMOutput(logits)
 
+    @staticmethod
+    def _a_log_pre_hook(
+        module,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ) -> None:
+        all_keys = list(state_dict.keys())
+        for key in all_keys:
+            if "A_log" in key:
+                A_log_key = key
+                A_minus_key = key.replace("A_log", "A_minus")
+                state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float())
+
 
 AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
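
Background for the change above (not part of the commit): the mixer keeps a precomputed A_minus = -exp(A_log) tensor while checkpoints store A_log, so a load_state_dict pre-hook must rewrite the key before loading. The commit moves that hook from each mixer module (which used its own prefix to find "A_log") to a single hook on the top-level model that scans every key. Below is a minimal, self-contained sketch of the model-level pattern, using illustrative ToyMixer/ToyModel names and shapes rather than the real NemotronH classes:

import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    """Stand-in for a mixer block that keeps A_minus instead of A_log."""

    def __init__(self, num_heads: int = 4):
        super().__init__()
        self.A_minus = nn.Parameter(torch.empty(num_heads))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mixer = ToyMixer()
        # One hook on the top-level model covers every nested A_log key.
        self.register_load_state_dict_pre_hook(self._a_log_pre_hook)

    @staticmethod
    def _a_log_pre_hook(module, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs) -> None:
        # Convert checkpoint keys in place before load_state_dict matches them.
        for key in list(state_dict.keys()):
            if "A_log" in key:
                state_dict[key.replace("A_log", "A_minus")] = -torch.exp(
                    state_dict.pop(key).float()
                )


# A checkpoint in the original parameterization (stores A_log) loads cleanly:
ckpt = {"mixer.A_log": torch.zeros(4)}
model = ToyModel()
model.load_state_dict(ckpt)   # the pre-hook rewrites the key on the fly
print(model.mixer.A_minus)    # -exp(0) = -1.0 for every head

Registering the hook once on the model, as the commit does, avoids per-module registration and handles any number of nested A_log entries, at the cost of scanning the whole state_dict on load.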

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 0 additions & 3 deletions
@@ -896,13 +896,10 @@ def _load_hook(
    # This is quite a hacky solution. A better solution would be to store extra_state in
    # the state_dict to identify whether the state_dict is sharded or not.
    key = prefix + param_key
-    ad_logger.debug(f"Sharder LOAD hook is called for '{key}'")
    if key not in state_dict:
        return
    p_to_load = state_dict[key]
-
    p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load)
-
    state_dict[key] = p_to_load
 
 
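
For context (not part of the commit), the sharder load hook follows this pattern: if the checkpoint entry still has the full, unsharded shape, split it for the local rank before loading; if it already matches the expected sharded shape, keep it as is. A hedged, standalone sketch with illustrative tp_rank/tp_size arguments and a simple dim-0 chunk standing in for the real f_split:

import torch


def shard_load_hook(state_dict, prefix, param_key, param_shape, tp_rank, tp_size):
    """Rewrite state_dict[prefix + param_key] in place before load_state_dict copies it."""
    key = prefix + param_key
    if key not in state_dict:
        return
    p_to_load = state_dict[key]
    if p_to_load.shape != param_shape:
        # Checkpoint holds the full weight: keep only this rank's slice.
        p_to_load = torch.chunk(p_to_load, tp_size, dim=0)[tp_rank]
    state_dict[key] = p_to_load


# Example: a full 8x4 weight in the checkpoint, loaded into a 4x4 local shard.
sd = {"linear.weight": torch.randn(8, 4)}
shard_load_hook(sd, "linear.", "weight", torch.Size([4, 4]), tp_rank=0, tp_size=2)
print(sd["linear.weight"].shape)   # torch.Size([4, 4])

As the in-code comment in the diff notes, the shape comparison is a heuristic for whether the checkpoint is already sharded; storing extra_state in the state_dict would make that explicit.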
