diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 20b8b1e30d4..263ef97dffe 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -459,7 +459,6 @@ def __init__(self, params: ModelArgs):
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
@@ -540,7 +539,7 @@ def forward(
 
         h = self.norm(h)
 
-        logits = self.output(h)
+        logits = torch.nn.functional.linear(h, self.tok_embeddings.weight)
 
         if self.output_prune_map is not None:
             # expand to original size so that downstream applications can use the logits as-is.
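
The diff ties the output projection to the token embedding: instead of a separate `nn.Linear(params.dim, params.vocab_size, bias=False)`, the logits are computed by reusing `self.tok_embeddings.weight` through `torch.nn.functional.linear`. Below is a minimal, self-contained sketch of the same weight-tying pattern, not the ExecuTorch code itself; `TinyLM`, `vocab_size`, and `dim` are illustrative names introduced here, not part of this diff.

```python
# Minimal weight-tying sketch (illustrative, not the llama_transformer.py model):
# the output projection reuses the token-embedding matrix instead of owning
# a separate nn.Linear layer.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 1000, dim: int = 64):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # No self.output = nn.Linear(dim, vocab_size, bias=False):
        # the (vocab_size, dim) embedding weight doubles as the output projection.

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # (batch, seq, dim)
        # F.linear(h, W) computes h @ W.T, so the (vocab_size, dim) embedding
        # matrix maps hidden states back to (batch, seq, vocab_size) logits.
        return F.linear(h, self.tok_embeddings.weight)


if __name__ == "__main__":
    model = TinyLM()
    logits = model(torch.randint(0, 1000, (2, 8)))
    print(logits.shape)  # torch.Size([2, 8, 1000])
```

Because the embedding and the output projection now share one tensor, the duplicated `vocab_size x dim` weight disappears from the exported model, which is the usual motivation for tying in memory-constrained deployments.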