From 62ede709eaaee2c3c674da9010b824548aa5c433 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=98=A4=EC=9C=A4=EC=A7=84=20Yoonjin=20Oh?= Date: Sat, 16 Dec 2023 10:51:56 +0900 Subject: [PATCH] Fix: Variable name --- lit_llama/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_llama/utils.py b/lit_llama/utils.py index a09ada20..bc224576 100644 --- a/lit_llama/utils.py +++ b/lit_llama/utils.py @@ -28,7 +28,7 @@ def llama_model_lookup(checkpoint: dict) -> str: Checks the width of the lm_head.weight matrix, as these uniquely identify the model. """ - embedding_size = checkpoint['transformer.wte.weight'].shape[1] + embedding_size = checkpoint['lm_head.weight'].shape[1] return llama_model_sizes[embedding_size]