
Commit d032a1b

add yarn metadata, move defaults to hparams
1 parent ed4d8f2 commit d032a1b

9 files changed: +68 -34 lines

common/common.h

Lines changed: 3 additions & 3 deletions
@@ -287,9 +287,9 @@ struct common_params {
     float rope_freq_base  = 0.0f;  // RoPE base frequency
     float rope_freq_scale = 0.0f;  // RoPE frequency scaling factor
     float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
-    float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = 32.0f;  // YaRN low correction dim
-    float yarn_beta_slow = 1.0f;   // YaRN high correction dim
+    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
+    float yarn_beta_fast = -1.0f;   // YaRN low correction dim
+    float yarn_beta_slow = -1.0f;   // YaRN high correction dim
     int32_t yarn_orig_ctx = 0;     // YaRN original context length

     // offload params

convert_hf_to_gguf.py

Lines changed: 4 additions & 0 deletions
@@ -2722,6 +2722,10 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
         self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
         self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+        self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+        self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+        self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+        self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])

         if temp_len := self.hparams.get("attn_temperature_len"):
             self.gguf_writer.add_attn_temperature_length(temp_len)
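The keys read here map one-to-one onto the new GGUF rope.scaling.yarn_* metadata. A minimal sketch of what the relevant part of a model's config.json (self.hparams) might look like; only the key names are taken from the code above, the values and surrounding structure are placeholders:

    # Hypothetical excerpt of a model's config.json as seen by the converter.
    hparams = {
        "scaling_factor": 8.0,                     # -> {arch}.rope.scaling.factor
        "original_max_position_embeddings": 8192,  # -> {arch}.rope.scaling.original_context_length
        "extrapolation_factor": 1.0,               # -> {arch}.rope.scaling.yarn_ext_factor
        "attn_factor": 1.0,                        # -> {arch}.rope.scaling.yarn_attn_factor
        "beta_fast": 8.0,                          # -> {arch}.rope.scaling.yarn_beta_fast
        "beta_slow": 1.0,                          # -> {arch}.rope.scaling.yarn_beta_slow
    }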

gguf-py/gguf/constants.py

Lines changed: 13 additions & 9 deletions
@@ -154,15 +154,19 @@ class Attention:
         SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"

     class Rope:
-        DIMENSION_COUNT      = "{arch}.rope.dimension_count"
-        DIMENSION_SECTIONS   = "{arch}.rope.dimension_sections"
-        FREQ_BASE            = "{arch}.rope.freq_base"
-        SCALING_TYPE         = "{arch}.rope.scaling.type"
-        SCALING_FACTOR       = "{arch}.rope.scaling.factor"
-        SCALING_ATTN_FACTOR  = "{arch}.rope.scaling.attn_factor"
-        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
-        SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
+        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
+        DIMENSION_SECTIONS       = "{arch}.rope.dimension_sections"
+        FREQ_BASE                = "{arch}.rope.freq_base"
+        SCALING_TYPE             = "{arch}.rope.scaling.type"
+        SCALING_FACTOR           = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"

     class Split:
         LLM_KV_SPLIT_NO = "split.no"
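These templates expand per architecture when written, just like the existing RoPE keys. For example (illustrative only, with "grok" standing in for the writer's arch string):

    import gguf

    # The {arch} placeholder is filled in by the GGUFWriter.add_* helpers.
    key = gguf.Keys.Rope.SCALING_YARN_BETA_FAST.format(arch="grok")
    print(key)  # grok.rope.scaling.yarn_beta_fast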

gguf-py/gguf/gguf_writer.py

Lines changed: 12 additions & 0 deletions
@@ -865,6 +865,18 @@ def add_rope_scaling_finetuned(self, value: bool) -> None:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
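The four new helpers are plain float32 KV setters, in the same pattern as add_rope_scaling_yarn_log_mul. A hedged usage sketch (file name, arch string and values are placeholders):

    import gguf

    writer = gguf.GGUFWriter("example.gguf", "grok")  # placeholder path / arch
    writer.add_rope_scaling_yarn_ext_factor(1.0)
    writer.add_rope_scaling_yarn_attn_factor(1.0)
    writer.add_rope_scaling_yarn_beta_fast(8.0)
    writer.add_rope_scaling_yarn_beta_slow(1.0)
    # ... remaining metadata and tensors would be added here, then serialized:
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()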

src/llama-arch.cpp

Lines changed: 14 additions & 10 deletions
@@ -174,16 +174,20 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,   "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

-    { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count" },
-    { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections" },
-    { LLM_KV_ROPE_FREQ_BASE,            "%s.rope.freq_base" },
-    { LLM_KV_ROPE_SCALE_LINEAR,         "%s.rope.scale_linear" },
-    { LLM_KV_ROPE_SCALING_TYPE,         "%s.rope.scaling.type" },
-    { LLM_KV_ROPE_SCALING_FACTOR,       "%s.rope.scaling.factor" },
-    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,  "%s.rope.scaling.attn_factor" },
-    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
-    { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned" },
-    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
+    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count" },
+    { LLM_KV_ROPE_DIMENSION_SECTIONS,       "%s.rope.dimension_sections" },
+    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base" },
+    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear" },
+    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type" },
+    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor" },
+    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,      "%s.rope.scaling.attn_factor" },
+    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
+    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned" },
+    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL,     "%s.rope.scaling.yarn_log_multiplier" },
+    { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  "%s.rope.scaling.yarn_ext_factor" },
+    { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   "%s.rope.scaling.yarn_beta_fast" },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   "%s.rope.scaling.yarn_beta_slow" },

     { LLM_KV_SPLIT_NO,    "split.no" },
     { LLM_KV_SPLIT_COUNT, "split.count" },

src/llama-arch.h

Lines changed: 4 additions & 0 deletions
@@ -188,6 +188,10 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,

     LLM_KV_SPLIT_NO,
     LLM_KV_SPLIT_COUNT,

src/llama-context.cpp

Lines changed: 7 additions & 11 deletions
@@ -35,10 +35,10 @@ llama_context::llama_context(

     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.no_perf          = params.no_perf;
@@ -69,10 +69,6 @@ llama_context::llama_context(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }

-    if (model.arch == LLM_ARCH_GROK && params.yarn_beta_fast == 32.0f) {
-        cparams.yarn_beta_fast = 8.0f;
-    }
-
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;

     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -2265,9 +2261,9 @@ llama_context_params llama_context_default_params() {
         /*.rope_freq_base   =*/ 0.0f,
         /*.rope_freq_scale  =*/ 0.0f,
         /*.yarn_ext_factor  =*/ -1.0f,
-        /*.yarn_attn_factor =*/ 1.0f,
-        /*.yarn_beta_fast   =*/ 32.0f,
-        /*.yarn_beta_slow   =*/ 1.0f,
+        /*.yarn_attn_factor =*/ -1.0f,
+        /*.yarn_beta_fast   =*/ -1.0f,
+        /*.yarn_beta_slow   =*/ -1.0f,
         /*.yarn_orig_ctx    =*/ 0,
         /*.defrag_thold     =*/ -1.0f,
         /*.cb_eval          =*/ nullptr,
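With the library defaults changed to -1.0f, a negative value in llama_context_params now means "not specified", and the effective value falls back to the model's hparams (whose defaults keep the old hard-coded numbers, see llama-hparams.h below). This also makes the removed Grok-specific beta_fast override redundant, since that default now lives in the model's hparams. The resolution rule, rendered as a small Python sketch for illustration only (the authoritative logic is the C++ above):

    def resolve_yarn_param(user_value: float, model_default: float) -> float:
        """Negative user value means 'unset': fall back to the model's hparams."""
        return user_value if user_value >= 0.0 else model_default

    assert resolve_yarn_param(-1.0, 8.0) == 8.0    # no override -> model value
    assert resolve_yarn_param(16.0, 8.0) == 16.0   # explicit override wins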

src/llama-hparams.h

Lines changed: 5 additions & 0 deletions
@@ -105,6 +105,11 @@ struct llama_hparams {
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul = 0.0f;

+    float yarn_ext_factor  = -1.0f;
+    float yarn_attn_factor = 1.0f;
+    float yarn_beta_fast   = 32.0f;
+    float yarn_beta_slow   = 1.0f;
+
     std::array<int, 4> rope_sections;

     // Sliding Window Attention (SWA)

src/llama-model.cpp

Lines changed: 6 additions & 1 deletion
@@ -686,6 +686,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GROK:
             {
                 // defaults for old GGUFs
+                hparams.yarn_beta_fast = 8.0f;
                 hparams.f_logit_scale     = 0.5773502691896257f;
                 hparams.f_embedding_scale = 78.38367176906169f;
                 hparams.f_attn_out_scale  = 0.08838834764831845f;
@@ -703,7 +704,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,  hparams.f_final_logit_softcapping,  false);

-                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false);
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH,  hparams.attn_temp_length,  false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  hparams.yarn_ext_factor,   false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor,  false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   hparams.yarn_beta_fast,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   hparams.yarn_beta_slow,    false);

                 switch (hparams.n_layer) {
                     case 64: type = LLM_TYPE_314B; break;
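Because ml.get_key(..., false) treats the keys as optional, older Grok GGUFs without the new metadata keep the previous behaviour via the hparams default (beta_fast = 8.0f set above), while newer conversions can override all four values. One way to check whether a converted file actually carries the new keys, sketched with the gguf Python package (path and the "grok" key prefix are assumptions):

    from gguf import GGUFReader

    reader = GGUFReader("example.gguf")  # placeholder path
    for name in ("yarn_ext_factor", "yarn_attn_factor", "yarn_beta_fast", "yarn_beta_slow"):
        field = reader.get_field(f"grok.rope.scaling.{name}")
        status = "present" if field is not None else "absent (loader keeps its hparams default)"
        print(name, status)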
