File tree: 1 file changed, +11 -1 lines changed
Columns: Original file line number | Diff line number | Diff line change
@@ -209,6 +209,8 @@ def __init__(self, args: ModelArgs):
209 209           self.args = args
210 210           self.model_type = args.model_type
211 211           self.model = HunyuanV1DenseModel(args)
    212 +         if not args.tie_word_embeddings:
    213 +             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
212 214
213 215       def __call__(
214 216           self,
@@ -217,8 +219,16 @@ def __call__(
217 219           cache=None,
218 220       ):
219 221           out = self.model(inputs, mask, cache)
220     -         return self.model.embed_tokens.as_linear(out)
    222 +         if self.args.tie_word_embeddings:
    223 +             return self.model.embed_tokens.as_linear(out)
    224 +         else:
    225 +             return self.lm_head(out)
221 226
222 227       @property
223 228       def layers(self):
224 229           return self.model.layers
    230 +
    231 +     def sanitize(self, weights):
    232 +         if self.args.tie_word_embeddings:
    233 +             weights.pop("lm_head.weight", None)
    234 +         return weights
You can’t perform that action at this time.
0 commit comments