
Commit 08cb3f7

cccclai authored and facebook-github-bot committed
share embedding and output (#6800)
Summary: Pull Request resolved: #6800 Differential Revision: D64189995
1 parent 4b7a60f commit 08cb3f7

1 file changed (+1, -2 lines)


examples/models/llama/llama_transformer.py

Lines changed: 1 addition & 2 deletions
@@ -459,7 +459,6 @@ def __init__(self, params: ModelArgs):
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
@@ -540,7 +539,7 @@ def forward(

         h = self.norm(h)

-        logits = self.output(h)
+        logits = torch.nn.functional.linear(h, self.tok_embeddings.weight)

         if self.output_prune_map is not None:
             # expand to original size so that downstream applications can use the logits as-is.
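
A minimal sketch of the weight-tying pattern this change appears to introduce: instead of keeping a separate self.output = nn.Linear(dim, vocab_size, bias=False), the final projection reuses the token-embedding matrix via torch.nn.functional.linear. The TinyModel class and the dim and vocab_size values below are illustrative stand-ins, not the ExecuTorch code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyModel(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # No separate output Linear: the output projection shares the
        # embedding weight, mirroring the diff above.

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # (batch, seq, dim)
        # ... transformer blocks and the final norm would run here ...
        # F.linear(h, W) computes h @ W.T, so the (vocab_size, dim)
        # embedding matrix doubles as the (dim -> vocab_size) projection.
        logits = F.linear(h, self.tok_embeddings.weight)  # (batch, seq, vocab_size)
        return logits

model = TinyModel(vocab_size=32, dim=8)
print(model(torch.randint(0, 32, (2, 5))).shape)  # torch.Size([2, 5, 32])

Sharing the matrix this way avoids storing a second vocab_size x dim parameter tensor, which is the usual motivation for tied input/output embeddings; the commit title suggests that is the intent here, though the diff itself does not state it.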

0 commit comments
