Skip to content

Commit 64574e1

Browse files
authored
fix hunyuan v1 dense (ml-explore#440)
1 parent 1b08ef1 commit 64574e1

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

mlx_lm/models/hunyuan_v1_dense.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def __init__(self, args: ModelArgs):
209209
self.args = args
210210
self.model_type = args.model_type
211211
self.model = HunyuanV1DenseModel(args)
212+
if not args.tie_word_embeddings:
213+
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
212214

213215
def __call__(
214216
self,
@@ -217,8 +219,16 @@ def __call__(
217219
cache=None,
218220
):
219221
out = self.model(inputs, mask, cache)
220-
return self.model.embed_tokens.as_linear(out)
222+
if self.args.tie_word_embeddings:
223+
return self.model.embed_tokens.as_linear(out)
224+
else:
225+
return self.lm_head(out)
221226

222227
@property
223228
def layers(self):
224229
return self.model.layers
230+
231+
def sanitize(self, weights):
232+
if self.args.tie_word_embeddings:
233+
weights.pop("lm_head.weight", None)
234+
return weights

0 commit comments

Comments
 (0)