Skip to content

Commit 85c797e

Browse files
authored
add support for Chinese-LLaMA-Alpaca (#2700)
1 parent af8d877 commit 85c797e

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

fastchat/model/model_adapter.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,6 +1681,31 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
16811681
return get_conv_template("llama2-chinese")
16821682

16831683

1684+
class Lamma2ChineseAlpacaAdapter(BaseModelAdapter):
1685+
"""The model adapter for ymcui/Chinese-LLaMA-Alpaca sft"""
1686+
1687+
def match(self, model_path: str):
1688+
return "chinese-alpaca" in model_path.lower()
1689+
1690+
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
1691+
revision = from_pretrained_kwargs.get("revision", "main")
1692+
tokenizer = AutoTokenizer.from_pretrained(
1693+
model_path,
1694+
trust_remote_code=True,
1695+
revision=revision,
1696+
)
1697+
model = AutoModelForCausalLM.from_pretrained(
1698+
model_path,
1699+
trust_remote_code=True,
1700+
low_cpu_mem_usage=True,
1701+
**from_pretrained_kwargs,
1702+
)
1703+
return model, tokenizer
1704+
1705+
def get_default_conv_template(self, model_path: str) -> Conversation:
1706+
return get_conv_template("llama2-chinese")
1707+
1708+
16841709
class VigogneAdapter(BaseModelAdapter):
16851710
"""The model adapter for vigogne (e.g., bofenghuang/vigogne-2-7b-chat)"""
16861711

@@ -1895,6 +1920,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
18951920
register_model_adapter(BGEAdapter)
18961921
register_model_adapter(E5Adapter)
18971922
register_model_adapter(Lamma2ChineseAdapter)
1923+
register_model_adapter(Lamma2ChineseAlpacaAdapter)
18981924
register_model_adapter(VigogneAdapter)
18991925
register_model_adapter(OpenLLaMaOpenInstructAdapter)
19001926
register_model_adapter(ReaLMAdapter)

fastchat/model/model_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ def get_model_info(name: str) -> ModelInfo:
330330
"https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat",
331331
"Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.",
332332
)
333+
register_model_info(
334+
["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"],
335+
"Chinese-Alpaca",
336+
"https://huggingface.co/hfl/chinese-alpaca-2-13b",
337+
"New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.",
338+
)
333339
register_model_info(
334340
["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"],
335341
"Vigogne-Instruct",

0 commit comments

Comments
 (0)