@@ -128,8 +128,6 @@ def __init__(self, config, layer_idx: int):
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.use_bias = config.use_bias

-        self.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
-
     def torch_forward(self, input_states):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
@@ -191,14 +189,6 @@ def torch_forward(self, input_states):
     def forward(self, hidden_states):
         return self.torch_forward(hidden_states)

-    @staticmethod
-    def _load_state_dict_pre_hook(module, state_dict, prefix, local_metadata, strict,
-                                  missing_keys, unexpected_keys, error_msgs) -> None:
-        A_log_key = prefix + "A_log"
-        A_minus_key = prefix + "A_minus"
-        if A_log_key in state_dict:
-            state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float())
-

 class NemotronHRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -592,6 +582,13 @@ def __init__(self, config):
         self.backbone = NemotronHModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Collect backbone modules that own an A_minus or A_log parameter
+        self.backbone_modules_with_A = []
+        for module_name, module in self.backbone.named_modules():
+            for param_name, _ in module.named_parameters(recurse=False):
+                if param_name in ("A_minus", "A_log"):
+                    self.backbone_modules_with_A.append((module_name, param_name))
+        self.register_load_state_dict_pre_hook(self._a_log_pre_hook)  # register once, not per matching parameter

         # Initialize weights and apply final processing
         self.post_init()
@@ -622,5 +619,23 @@ def forward(

         return NemotronHCausalLMOutput(logits)

+    @staticmethod
+    def _a_log_pre_hook(
+        module,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ) -> None:
+        # Rewrite every legacy A_log entry as A_minus = -exp(A_log), with exp taken in float32
+        all_keys = list(state_dict.keys())
+        for key in all_keys:
+            if "A_log" in key:
+                A_minus_key = key.replace("A_log", "A_minus")
+                state_dict[A_minus_key] = -torch.exp(state_dict.pop(key).float())
+

 AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
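
For context, a standalone sketch of what the relocated pre-hook does: a checkpoint that still stores the legacy A_log parameterization is rewritten in place to the A_minus = -exp(A_log) form before weights are loaded, with the exponentiation done in float32. The key name below is illustrative, not taken from a real NemotronH checkpoint.

import torch

# Hypothetical state dict with a legacy A_log entry (name is illustrative)
state_dict = {"backbone.layers.0.mixer.A_log": torch.zeros(4)}

# Same rewrite the pre-hook applies before load_state_dict consumes the dict
for key in list(state_dict.keys()):
    if "A_log" in key:
        state_dict[key.replace("A_log", "A_minus")] = -torch.exp(
            state_dict.pop(key).float()
        )

print(state_dict)  # {'backbone.layers.0.mixer.A_minus': tensor([-1., -1., -1., -1.])}

Moving the hook from each mixer layer to the top-level model lets one pass over the full state dict convert every layer's A_log key, instead of running one prefix-scoped hook per layer.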