
Commit 1fc5bf5

support glm-4-9b-chat
Signed-off-by: XingXing Qiao <[email protected]>
1 parent f3bc337 commit 1fc5bf5

File tree

5 files changed: +116, -7 lines changed


convert-hf-to-gguf.py

Lines changed: 93 additions & 3 deletions
@@ -476,6 +476,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d":
             # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct
             res = "smaug-bpe"
+        if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
 
         if res is None:
             logger.warning("\n")
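Note: the new chkhsh branch maps glm-4-9b-chat's tokenizer to the "chatglm-bpe" pre-tokenizer name that gets written into the GGUF. Below is a minimal sketch, not the committed code, of how get_vocab_base_pre() derives that hash, assuming the script's usual approach of hashing the token ids produced for a fixed checksum string; chktxt here is only a placeholder for that string.

# hedged sketch: hash the encoding of a fixed checksum text to identify the pre-tokenizer
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "..."  # placeholder; the real chktxt in convert-hf-to-gguf.py is a long fixed multilingual string
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)  # with the real chktxt this should match the b6e8e151... hash registered above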
@@ -2714,7 +2717,7 @@ def write_tensors(self):
 class ChatGLMModel(Model):
     model_arch = gguf.MODEL_ARCH.CHATGLM
 
-    def set_vocab(self):
+    def set_vocab_chatglm3(self):
         dir_model = self.dir_model
         hparams = self.hparams
         tokens: list[bytearray] = []
@@ -2725,7 +2728,8 @@ def set_vocab(self):
         tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
         vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
         assert max(tokenizer.get_vocab().values()) < vocab_size
-
+        print(vocab_size)
+        print(max(tokenizer.get_vocab().values()))
         for token_id in range(vocab_size):
             piece = tokenizer._convert_id_to_token(token_id)
             if token_id == 0:
@@ -2774,6 +2778,91 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
+            self.set_vocab_chatglm3()
+            return
+
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["padded_vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) >= 2 and len(merged) <= 7
+            merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.get_added_vocab()
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                if tokenizer.added_tokens_decoder[i].special:
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        if len(special_vocab.special_token_ids) == 0:
+            special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+            special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_name(self.dir_model.name)
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
@@ -2934,7 +3023,8 @@ def main() -> None:
     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
         model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
-
+        print(model_class)
+        print(model_instance)
         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
 
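Note: glm-4-9b-chat ships a tiktoken-style tokenizer that exposes mergeable_ranks (byte sequence to rank) instead of an explicit merges list, so the new set_vocab() re-derives each multi-byte token's final merge by running BPE over its bytes with that token's own rank as the cutoff. A toy sketch of that recovery follows; the ranks table is made up for illustration, real vocabularies have tens of thousands of entries.

# toy illustration of the rank-based merge recovery used by ChatGLMModel.bpe()
mergeable_ranks = {
    b"h": 0, b"e": 1, b"l": 2, b"o": 3,
    b"he": 4, b"ll": 5, b"hell": 6, b"hello": 7,
}

def bpe(ranks, token, max_rank=None):
    # split the token into single bytes, then repeatedly merge the neighbouring
    # pair with the lowest rank, stopping once only merges at or above max_rank remain
    parts = [bytes([b]) for b in token]
    while True:
        best = min(
            ((ranks[a + b], i) for i, (a, b) in enumerate(zip(parts, parts[1:])) if a + b in ranks),
            default=None,
        )
        if best is None or (max_rank is not None and best[0] >= max_rank):
            break
        i = best[1]
        parts = parts[:i] + [parts[i] + parts[i + 1]] + parts[i + 2:]
    return parts

# the last merge producing b"hello" was b"hell" + b"o", so the recorded merge line is "hell o"
print(bpe(mergeable_ranks, b"hello", max_rank=mergeable_ranks[b"hello"]))  # [b'hell', b'o']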

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -3056,7 +3056,7 @@ int main(int argc, char ** argv) {
         chat.push_back({{"role", "user"}, {"content", "Hello"}});
         chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
         chat.push_back({{"role", "user"}, {"content", "How are you?"}});
-
+        printf("sparams.chat_template: #%s#\n", sparams.chat_template.c_str());
         const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
 
         LOG_INFO("chat template", {

llama.cpp

Lines changed: 17 additions & 3 deletions
@@ -4508,6 +4508,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 28: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -4636,9 +4637,9 @@ static void llm_load_vocab(
         if (merges_keyidx == -1) {
             throw std::runtime_error("cannot find tokenizer merges in model file\n");
         }
-
+        printf("merges_keyidx: %d\n", merges_keyidx);
         const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-
+        printf("n_merges: %d\n", n_merges);
         for (int i = 0; i < n_merges; i++) {
             const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
             GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
@@ -4728,6 +4729,9 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "smaug-bpe") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -11449,7 +11453,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
-
+                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -13032,6 +13036,7 @@ struct llm_tokenizer_bpe {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_DBRX:
             case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
                 word_collection = unicode_regex_split(text, {
                     // same as llama3
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@@ -18741,6 +18746,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
+    } else if (tmpl == "ChatGLM4") {
+        ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
     } else {
         // template not supported
         return -1;
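Note: the new "ChatGLM4" branch prefixes the prompt with [gMASK]<sop> and renders each message as <|role|> followed by a newline and the content, which matches the expected string added to tests/test-chat-template.cpp. A small Python mirror of that string construction, for illustration only, with made-up sample messages:

# Python mirror of the "ChatGLM4" branch in llama_chat_apply_template_internal (illustrative sketch)
def format_chatglm4(messages, add_ass=True):
    out = "[gMASK]" + "<sop>"
    for m in messages:
        out += "<|" + m["role"] + "|>" + "\n" + m["content"]
    if add_ass:
        out += "<|assistant|>"
    return out

print(format_chatglm4([
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
]))
# [gMASK]<sop><|system|>
# You are a helpful assistant<|user|>
# Hello<|assistant|>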

llama.h

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
         LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 15,
     };
 
     // note: these values should be synchronized with ggml_rope

tests/test-chat-template.cpp

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,8 @@ int main(void) {
         "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
         // ChatGLM3
         "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        // ChatGLM4
+        "ChatGLM4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -97,6 +99,8 @@ int main(void) {
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         // ChatGLM3
         "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+        // ChatGLM4
+        "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
