Skip to content

Commit 4c96be0

Browse files
committed
Embedding-weights alias fix for the latest Falcon models
1 parent a094488 commit 4c96be0

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

server/text_generation_server/inference_engine/hf_custom_tp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ def __init__(
8585
elif model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
8686
if sharded and self._config.alibi:
8787
raise NotImplementedError("TP is not supported for Falcon models using alibi")
88-
aliases = {"transformer.word_embeddings.weight": ["lm_head.weight"]}
88+
aliases = {
89+
"transformer.word_embeddings.weight": ["lm_head.weight"],
90+
"lm_head.weight": ["transformer.word_embeddings.weight"],
91+
}
8992
from text_generation_server.models.custom_modeling.flash_rw_modeling import FlashRWForCausalLM
9093
model_class = FlashRWForCausalLM
9194

0 commit comments

Comments (0)