Commit 050abed

Revert some files
1 parent 5c04994 commit 050abed

4 files changed, +4970 -4670 lines changed


convert_hf_to_gguf.py

Lines changed: 44 additions & 3 deletions
@@ -2190,11 +2190,52 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
 @ModelBase.register("Plamo2ForCausalLM")
 class Plamo2Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.PLAMO2
-

     def set_vocab(self):
-        # Plamo2 uses sentencepiece tokenizer similar to Llama
-        self._set_vocab_sentencepiece()
+        dir_model = self.dir_model
+        hparams = self.hparams
+
+        tokens: list[bytes] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = len(tokenizer.vocab)
+        # Since we are checking the maximum index, we need to ensure it's strictly less than vocab_size,
+        # because vocab_size is the count of items, and indexes start at 0.
+        max_vocab_index = max(tokenizer.get_vocab().values())
+        if max_vocab_index >= vocab_size:
+            raise ValueError("Vocabulary size exceeds expected maximum size.")
+
+        reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
+        for token_id in range(vocab_size):
+            token_text = reverse_vocab[token_id].encode('utf-8')
+            # replace "\x00" with a string of length > 0
+            if token_text == b"\x00":
+                toktype = gguf.TokenType.BYTE  # special
+                token_text = f"<{token_text}>".encode('utf-8')
+            elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
+                toktype = gguf.TokenType.BYTE  # special
+            elif reverse_vocab[token_id] in added_vocab:
+                if tokenizer.added_tokens_decoder[token_id].special:
+                    toktype = gguf.TokenType.CONTROL
+                else:
+                    toktype = gguf.TokenType.USER_DEFINED
+            else:
+                toktype = gguf.TokenType.NORMAL
+
+            tokens.append(token_text)
+            toktypes.append(toktype)
+
+        # self.gguf_writer.add_tokenizer_model("llama")
+        # self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
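
As a standalone illustration (not part of the commit), the token-type mapping in set_vocab() above reduces to a small amount of branching. The sketch below reproduces it on made-up token strings; the plain string labels stand in for gguf.TokenType values, and "<|plamo:bos|>" is a hypothetical added special token.

import re

def classify(token_text: bytes, is_added: bool, is_special: bool) -> str:
    # mirrors the branching in Plamo2Model.set_vocab(); labels stand in for gguf.TokenType
    if token_text == b"\x00":
        return "BYTE"          # re-encoded as f"<{token_text}>" by the converter
    if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
        return "BYTE"          # raw byte token such as <0x0A>
    if is_added:
        return "CONTROL" if is_special else "USER_DEFINED"
    return "NORMAL"

print(classify(b"<0x0A>", False, False))        # BYTE
print(classify(b"<|plamo:bos|>", True, True))   # CONTROL (hypothetical special token)
print(classify(b"hello", False, False))         # NORMAL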

src/llama-hparams.cpp

Lines changed: 72 additions & 2 deletions
@@ -2,7 +2,69 @@

 #include "ggml.h"

-// Only define functions that are not already inline in the header
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_head(uint32_t il) const {
+    if (il < n_layer) {
+        return n_head_arr[il];
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+uint32_t llama_hparams::n_head_kv(uint32_t il) const {
+    if (il < n_layer) {
+        return n_head_kv_arr[il];
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+uint32_t llama_hparams::n_ff(uint32_t il) const {
+    if (il < n_layer) {
+        return n_ff_arr[il];
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+uint32_t llama_hparams::n_gqa(uint32_t il) const {
+    const uint32_t n_head    = this->n_head(il);
+    const uint32_t n_head_kv = this->n_head_kv(il);
+
+    if (n_head_kv == 0) {
+        return 0;
+    }
+
+    return n_head/n_head_kv;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
+    const uint32_t n_head_kv = this->n_head_kv(il);
+
+    return n_embd_head_k * n_head_kv;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
+    const uint32_t n_head_kv = this->n_head_kv(il);
+
+    return n_embd_head_v * n_head_kv;
+}
+
 uint32_t llama_hparams::n_embd_k_s() const {
     if (wkv_head_size != 0) {
         // for RWKV models
@@ -22,4 +84,12 @@ uint32_t llama_hparams::n_embd_v_s() const {

     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
-}
+}
+
+bool llama_hparams::is_swa(uint32_t il) const {
+    if (il < n_layer) {
+        return swa_layers[il];
+    }
+
+    GGML_ABORT("fatal error");
+}
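
For reference, the new set_swa_pattern() boils down to one modular test per layer. A standalone Python sketch (hypothetical 8-layer model, not llama.cpp code) that evaluates the same expression and matches the n_pattern = 3 example documented in llama-hparams.h below:

def swa_pattern(n_layer: int, n_pattern: int) -> list[bool]:
    # True means the layer uses sliding-window attention, False means it is dense
    return [n_pattern == 0 or (il % n_pattern < (n_pattern - 1)) for il in range(n_layer)]

print(swa_pattern(8, 3))  # [True, True, False, True, True, False, True, True]
print(swa_pattern(8, 1))  # all False: every layer is dense
print(swa_pattern(8, 0))  # all True:  every layer uses SWA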

src/llama-hparams.h

Lines changed: 174 additions & 11 deletions
@@ -1,16 +1,25 @@
 #pragma once

-#include "llama.h" // This now provides the primary definitions
+#include "llama.h"

-// #include <array> // llama.h includes this
+#include <array>

-// Internal constant if not defined in the public API
+// bump if necessary
+#define LLAMA_MAX_LAYERS 512
 #define LLAMA_MAX_EXPERTS 256 // DeepSeekV3

-// Internal helper structs if they are not part of the public API
-// and are used by files including src/llama-hparams.h
-// If these are actually part of the public llama_hparams, they should be in include/llama.h
-// For now, assuming they might be used by other src files that include this.
+enum llama_expert_gating_func_type {
+    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+};
+
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -21,7 +30,161 @@ struct llama_hparams_convnext {
 struct llama_hparams_convnext {
     uint32_t n_layer;
 };

-// All other definitions previously in this file (LLAMA_MAX_LAYERS,
-// enum llama_expert_gating_func_type, enum llama_swa_type,
-// struct llama_hparams, and the static_assert) are removed
-// to defer to the definitions in "llama.h".
+struct llama_hparams {
+    bool vocab_only;
+    bool rope_finetuned;
+    bool use_par_res;
+    bool swin_norm;
+
+    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_embd;
+    uint32_t n_embd_features = 0;
+    uint32_t n_layer;
+    uint32_t n_rot;
+    uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
+    uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
+    uint32_t n_expert = 0;
+    uint32_t n_expert_used = 0;
+    uint32_t n_rel_attn_bkts = 0;
+
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
+    // for WavTokenizer
+    struct llama_hparams_posnet   posnet;
+    struct llama_hparams_convnext convnext;
+
+    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
+    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
+    std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
+
+    uint32_t n_layer_dense_lead = 0;
+    uint32_t n_lora_q           = 0;
+    uint32_t n_lora_kv          = 0;
+    uint32_t n_ff_exp           = 0;
+    uint32_t n_ff_shexp         = 0;
+    uint32_t n_expert_shared    = 0;
+    uint32_t n_norm_groups      = 0;
+
+    float    expert_weights_scale = 0.0;
+    bool     expert_weights_norm  = false;
+    uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers   = 0;
+
+    float f_norm_eps;
+    float f_norm_rms_eps;
+    float f_norm_group_eps;
+
+    float f_attn_logit_softcapping  = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
+    // for RWKV
+    uint32_t rescale_every_n_layers = 0;
+    uint32_t time_mix_extra_dim     = 0;
+    uint32_t time_decay_extra_dim   = 0;
+    uint32_t wkv_head_size          = 0;
+    uint32_t token_shift_count      = 2;
+    uint32_t n_lora_decay           = 0;
+    uint32_t n_lora_iclr            = 0;
+    uint32_t n_lora_value_res_mix   = 0;
+    uint32_t n_lora_gate            = 0;
+
+    float    rope_attn_factor = 1.0f;
+    float    rope_freq_base_train;
+    float    rope_freq_base_train_swa;
+    float    rope_freq_scale_train;
+    float    rope_freq_scale_train_swa;
+    uint32_t n_ctx_orig_yarn;
+    float    rope_yarn_log_mul;
+
+    std::array<int, 4> rope_sections;
+
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+
+    // for State Space Models
+    uint32_t ssm_d_conv  = 0;
+    uint32_t ssm_d_inner = 0;
+    uint32_t ssm_d_state = 0;
+    uint32_t ssm_dt_rank = 0;
+
+    bool ssm_dt_b_c_rms = false;
+
+    float f_clamp_kqv      = 0.0f;
+    float f_max_alibi_bias = 0.0f;
+    float f_logit_scale    = 0.0f;
+
+    // Additional scale factors (Granite/Granite MoE)
+    float f_residual_scale  = 0.0f;
+    float f_embedding_scale = 0.0f;
+    float f_attention_scale = 0.0f;
+
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;
+    bool use_kq_norm   = true;
+
+    // llama4
+    uint32_t n_moe_layer_step        = 0;
+    uint32_t n_no_rope_layer_step    = 4;
+    uint32_t n_attn_temp_floor_scale = 8192;
+    float    f_attn_temp_scale       = 0.1;
+
+    // needed by encoder-decoder models (e.g. T5, FLAN-T5)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/8141
+    llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
+
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
+
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
+    uint32_t n_head(uint32_t il = 0) const;
+
+    uint32_t n_head_kv(uint32_t il = 0) const;
+
+    uint32_t n_ff(uint32_t il = 0) const;
+
+    uint32_t n_gqa(uint32_t il = 0) const;
+
+    // dimension of key embeddings across all k-v heads
+    uint32_t n_embd_k_gqa(uint32_t il = 0) const;
+
+    // dimension of value embeddings across all k-v heads
+    uint32_t n_embd_v_gqa(uint32_t il = 0) const;
+
+    // dimension of the rolling state embeddings
+    // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+    uint32_t n_embd_k_s() const;
+
+    // dimension of the recurrent state embeddings
+    uint32_t n_embd_v_s() const;
+
+    bool is_swa(uint32_t il) const;
+};
+
+static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
+
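
The per-layer accessors declared above reduce to simple arithmetic on the head counts and head dimensions. A numeric sketch in Python with made-up hparams values (not taken from any real model):

n_head        = 32    # query heads              (hypothetical)
n_head_kv     = 8     # key/value heads          (hypothetical)
n_embd_head_k = 128   # per-head key dimension   (hypothetical)
n_embd_head_v = 128   # per-head value dimension (hypothetical)

n_gqa        = n_head // n_head_kv        # 4: query heads per k/v head
n_embd_k_gqa = n_embd_head_k * n_head_kv  # 1024: key embeddings across all k/v heads
n_embd_v_gqa = n_embd_head_v * n_head_kv  # 1024: value embeddings across all k/v heads
print(n_gqa, n_embd_k_gqa, n_embd_v_gqa)  # 4 1024 1024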
