
Commit 6b88093

fixed according to reviewer's comments
1 parent 276df3f commit 6b88093

File tree

9 files changed: 144 additions & 191 deletions


convert_hf_to_gguf.py

Lines changed: 9 additions & 38 deletions
@@ -1270,28 +1270,6 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-    def _set_vocab_pangu_embedded(self):
-        tokens, scores, toktypes = self._create_vocab_sentencepiece()
-
-        self.gguf_writer.add_tokenizer_model("pangu_embedded")
-        self.gguf_writer.add_tokenizer_pre("default")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
-        self.gguf_writer.add_token_types(toktypes)
-
-        tokenizer_config_file = self.dir_model / "tokenizer_config.json"
-        if tokenizer_config_file.is_file():
-            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
-                tokenizer_config_json = json.load(f)
-            if "chat_template" in tokenizer_config_json:
-                self.gguf_writer.add_chat_template(tokenizer_config_json["chat_template"])
-            if "add_prefix_space" in tokenizer_config_json:
-                self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
-        special_vocab.add_to_gguf(self.gguf_writer)
-
-
     def _set_vocab_rwkv_world(self):
         assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
         vocab_size = self.hparams.get("vocab_size", 65536)
@@ -7212,12 +7190,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
 class PanguEmbeddedModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PANGU_EMBED
 
-    def set_vocab(self):
-        try:
-            self._set_vocab_pangu_embedded()
-        except FileNotFoundError:
-            print("pangu vocab set fail, fallback to sentencepiece!")
-            self._set_vocab_sentencepiece()
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
 
         tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
         if tokenizer_config_file.is_file():
@@ -7236,18 +7210,15 @@ def set_gguf_parameters(self):
         rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
         self.gguf_writer.add_rope_dimension_count(rope_dim)
 
-        if (head_dim := hparams.get("head_dim")) is None:
-            if "hidden_size" in hparams and "num_attention_heads" in hparams:
-                head_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
-
-        if head_dim is not None:
-            self.gguf_writer.add_key_length(head_dim)
-            self.gguf_writer.add_value_length(head_dim)
+        if hparams.get("head_dim") is None:
+            self.gguf_writer.add_key_length(rope_dim)
+            self.gguf_writer.add_value_length(rope_dim)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid
-        n_head = self.find_hparam(["n_heads", "num_attention_heads"])
-        n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
         return [(self.map_tensor_name(name), data_torch)]
 
 
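With the custom vocab path removed, the converter now relies on the generic SentencePiece and chat-template handling. A minimal C++ sketch (not part of the commit; "pangu-embedded.gguf" is a placeholder path and the llama.h calls reflect the current public API as I understand it) for checking that a converted GGUF still carries the chat template:

// Minimal sketch: load only the metadata/vocab of a converted GGUF and print the
// chat template stored by convert_hf_to_gguf.py.
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    const char * path = argc > 1 ? argv[1] : "pangu-embedded.gguf"; // placeholder path

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.vocab_only = true; // metadata and vocab are enough for this check

    llama_model * model = llama_model_load_from_file(path, mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load %s\n", path);
        llama_backend_free();
        return 1;
    }

    const char * tmpl = llama_model_chat_template(model, /*name=*/ nullptr);
    printf("chat template: %s\n", tmpl != nullptr ? tmpl : "(none)");

    llama_model_free(model);
    llama_backend_free();
    return 0;
}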

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,6 @@ add_library(llama
             models/mamba.cpp
             models/minicpm3.cpp
             models/minimax-m2.cpp
-            models/pangu_embedded.cpp
             models/mpt.cpp
             models/nemotron-h.cpp
             models/nemotron.cpp
@@ -100,6 +99,7 @@ add_library(llama
             models/openai-moe-iswa.cpp
             models/openelm.cpp
             models/orion.cpp
+            models/pangu-embedded.cpp
             models/phi2.cpp
             models/phi3.cpp
             models/plamo.cpp

src/llama-arch.cpp

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_APERTUS, "apertus" },
     { LLM_ARCH_MINIMAX_M2, "minimax-m2" },
     { LLM_ARCH_COGVLM, "cogvlm" },
-    {LLM_ARCH_PANGU_EMBED, "pangu_embedded" },
+    { LLM_ARCH_PANGU_EMBED, "pangu_embedded" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 

src/llama-chat.cpp

Lines changed: 1 addition & 4 deletions
@@ -214,7 +214,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_SEED_OSS;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
         return LLM_CHAT_TEMPLATE_GROK_2;
-    } else if (tmpl_contains("[unused9]") && tmpl_contains("[unused10]")) {
+    } else if (tmpl_contains("[unused9]") && tmpl_contains("message['content'] + '[unused10]'")) {
         return LLM_CHAT_TEMPLATE_PANGU_EMBED;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
@@ -840,9 +840,6 @@ int32_t llm_chat_apply_template(
             ss << "[unused9]工具:" << content << "[unused10]";
         } else if (role == "function") {
             ss << "[unused9]方法:" << content << "[unused10]";
-        } else {
-            // unknown role
-            ss << "[unused9]" << role << "" << content << "[unused10]";
         }
     }
     if (add_ass) {
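The tightened check means a chat template must contain both "[unused9]" and the literal "message['content'] + '[unused10]'" before it is treated as the Pangu Embedded format. A rough sketch of exercising the detection through the public llama_chat_apply_template API; the Jinja fragment below is hypothetical and exists only to satisfy the two substring checks:

// Sketch only: feed a hypothetical Jinja fragment containing the two required
// substrings, then let the built-in PANGU_EMBED writer format a short conversation.
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    const char * tmpl =
        "{% for message in messages %}"
        "{{ '[unused9]' + message['role'] + '：' + message['content'] + '[unused10]' }}"
        "{% endfor %}";

    llama_chat_message msgs[] = {
        { "system", "你是一个乐于助人的助手" },
        { "user",   "你好" },
    };

    std::vector<char> buf(4096);
    const int32_t n = llama_chat_apply_template(tmpl, msgs, 2, /*add_ass=*/ true,
                                                buf.data(), (int32_t) buf.size());
    if (n < 0) {
        fprintf(stderr, "template was not detected as a known format\n");
        return 1;
    }

    printf("%.*s\n", n, buf.data());
    return 0;
}

Note also that the unknown-role fallback was dropped from the writer, so roles other than the handled ones are now silently skipped rather than emitted with a raw role name.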

src/llama-model.cpp

Lines changed: 7 additions & 7 deletions
@@ -6275,11 +6275,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_PANGU_EMBED:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
 
-                    // openPanguEmbedded-1B model's lm_head/output is 'tie_word_embeddings', the 7B model is not
-                    if(type == LLM_TYPE_1B){
+                    // if output is NULL, init from the input tok embed
+                    if(output == NULL){
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
@@ -6295,26 +6297,24 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
                         // bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0);
                         layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
                         layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
                         if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
                             layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                             layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
-                        }
-                        else {
+                        } else {
                             layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                         }
 
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                     }
-
                 } break;
             default:
                 throw std::runtime_error("unknown architecture");

src/llama-vocab.cpp

Lines changed: 0 additions & 14 deletions
@@ -1805,20 +1805,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         special_sep_id = LLAMA_TOKEN_NULL;
         special_pad_id = 3; // <|plamo:pad|>
         special_mask_id = LLAMA_TOKEN_NULL;
-    } else if (tokenizer_model == "pangu_embedded") {
-        type = LLAMA_VOCAB_TYPE_SPM;
-
-        // default special tokens
-        special_bos_id = 1;
-        special_eos_id = 45892;
-        special_unk_id = 0;
-        special_sep_id = LLAMA_TOKEN_NULL;
-        special_pad_id = 0;
-        special_mask_id = LLAMA_TOKEN_NULL;
-
-        add_space_prefix = true;
-        add_bos = true;
-        add_eos = false;
     } else {
         throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
     }
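With the special-cased branch gone, a converted Pangu Embedded GGUF is expected to declare the standard SentencePiece tokenizer (the generic converter path writes tokenizer.ggml.model as "llama"), with special-token ids taken from GGUF metadata rather than hard-coded defaults. A small sketch (not from the commit; the path is a placeholder) that inspects this key with the ggml gguf API:

// Sketch: read tokenizer.ggml.model straight from a converted GGUF file.
#include "gguf.h"

#include <cstdio>

int main(int argc, char ** argv) {
    const char * path = argc > 1 ? argv[1] : "pangu-embedded.gguf"; // placeholder path

    struct gguf_init_params params = { /*no_alloc=*/ true, /*ctx=*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(path, params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to open %s\n", path);
        return 1;
    }

    const int64_t kid = gguf_find_key(ctx, "tokenizer.ggml.model");
    if (kid >= 0) {
        // expected to print "llama" for a model converted via _set_vocab_sentencepiece
        printf("tokenizer.ggml.model = %s\n", gguf_get_val_str(ctx, kid));
    }

    gguf_free(ctx);
    return 0;
}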

src/models/models.h

Lines changed: 4 additions & 4 deletions
@@ -317,10 +317,6 @@ struct llm_build_minimax_m2 : public llm_graph_context {
     llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_pangu_embedded : public llm_graph_context {
-    llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params);
-};
-
 struct llm_build_mpt : public llm_graph_context {
     llm_build_mpt(const llama_model & model, const llm_graph_params & params);
 };
@@ -365,6 +361,10 @@ struct llm_build_orion : public llm_graph_context {
     llm_build_orion(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_pangu_embedded : public llm_graph_context {
+    llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_phi2 : public llm_graph_context {
     llm_build_phi2(const llama_model & model, const llm_graph_params & params);
 };

src/models/pangu-embedded.cpp

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+#include "models.h"
+
+
+llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self attention
+        {
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                model.layers[il].wo, model.layers[il].bo,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+                model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cb(cur, "ffn_out", il);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    if (model.output_b != nullptr) {
+        cur = ggml_add(ctx0, cur, model.output_b);
+    }
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
