From 08cb3f779c62c9ba61ceb5010bdcba909f76ccdf Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 12 Nov 2024 14:56:47 -0800
Subject: [PATCH] share embedding and output (#6800)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/6800

Differential Revision: D64189995
---
 examples/models/llama/llama_transformer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 20b8b1e30d4..263ef97dffe 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -459,7 +459,6 @@ def __init__(self, params: ModelArgs):
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
@@ -540,7 +539,7 @@ def forward(
 
         h = self.norm(h)
 
-        logits = self.output(h)
+        logits = torch.nn.functional.linear(h, self.tok_embeddings.weight)
 
         if self.output_prune_map is not None:
             # expand to original size so that downstream applications can use the logits as-is.
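
Note: the change above replaces the standalone output projection with a call to
torch.nn.functional.linear against the token-embedding weight, i.e. the embedding
and output matrices are shared (weight tying), dropping the separate
nn.Linear(dim, vocab_size) parameters. The snippet below is a minimal standalone
sketch of that pattern; the class name TinyTiedLM and its arguments are
illustrative placeholders, not part of the ExecuTorch Llama model.

# Minimal sketch of the weight-tying pattern in this patch (toy model; the
# name TinyTiedLM and its constructor arguments are hypothetical).
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        # One (vocab_size, dim) weight matrix, reused for both token lookup
        # and the final projection back to vocabulary logits.
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # (batch, seq, dim)
        # F.linear(h, W) computes h @ W.T; with W of shape (vocab_size, dim)
        # this yields (batch, seq, vocab_size) logits without a separate
        # nn.Linear(dim, vocab_size, bias=False) and its extra parameters.
        return F.linear(h, self.tok_embeddings.weight)


if __name__ == "__main__":
    model = TinyTiedLM(vocab_size=32, dim=16)
    logits = model(torch.randint(0, 32, (2, 5)))
    print(logits.shape)  # torch.Size([2, 5, 32])

Because the two matrices are the same tensor, the exported model carries one
copy of the vocab_size x dim weights instead of two, which is the memory saving
this diff targets.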