File tree Expand file tree Collapse file tree 1 file changed +1
-2
lines changed
Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -459,7 +459,6 @@ def __init__(self, params: ModelArgs):
459459 for layer_id in range (params .n_layers ):
460460 self .layers .append (TransformerBlock (layer_id , params ))
461461 self .norm = RMSNorm (params .dim , eps = params .norm_eps )
462- self .output = nn .Linear (params .dim , params .vocab_size , bias = False )
463462 self .use_kv_cache = params .use_kv_cache
464463 self .generate_full_logits = params .generate_full_logits
465464 self .max_seq_len = params .max_seq_len
@@ -540,7 +539,7 @@ def forward(
540539
541540 h = self .norm (h )
542541
543- logits = self . output ( h )
542+ logits = torch . nn . functional . linear ( h , self . tok_embeddings . weight )
544543
545544 if self .output_prune_map is not None :
546545 # expand to original size so that downstream applications can use the logits as-is.
You can’t perform that action at this time.
0 commit comments