
Commit 609a2d0

models : fix YaRN regression + consolidate logic (ggml-org#18006)
* models : fix YaRN regression + consolidate logic
* cont : fix the fix
* cont : remove header
* cont : add header
1 parent a63cbaf commit 609a2d0

File tree

6 files changed: +40 -46 lines changed


src/llama-context.cpp

Lines changed: 38 additions & 0 deletions

@@ -9,6 +9,7 @@
 #include "llama-model.h"

 #include <cinttypes>
+#include <cmath>
 #include <cstring>
 #include <limits>
 #include <stdexcept>
@@ -72,6 +73,43 @@ llama_context::llama_context(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }

+    if (cparams.yarn_ext_factor != 0) {
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        const float factor = 1.0f / cparams.rope_freq_scale;
+
+        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+        if (hparams.rope_yarn_log_mul != 0.0f) {
+            // note: here we assume `mscale == 1.0f`
+            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+            float mscale = 1.0f;
+            const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // special-case DEEPSEEK v2:
+            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+                mscale = mscale_all_dims;
+            }
+
+            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+        } else {
+            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
+        }
+
+        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+        //
+        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+        //      https://github.com/ggml-org/llama.cpp/pull/17945
+        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+    }
+
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;

     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {

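For readability, the consolidated logic added above can be read as a standalone function. The sketch below is illustrative only: the compute_yarn_attn_factor helper is hypothetical, plain arguments stand in for the real cparams/hparams fields, and the example values in main are made up rather than taken from any specific model config. It mirrors the hunk above: the YaRN mscale ratio from the Transformers reference implementation, the DeepSeek-V2 special case, and the cancellation of the 1 + 0.1*ln(1/freq_scale) scaling applied inside ggml's rope op.

#include <cmath>
#include <cstdio>

// Hypothetical helper mirroring the logic the commit moves into the
// llama_context constructor; it does not exist in llama.cpp itself.
static float compute_yarn_attn_factor(float rope_freq_scale, float rope_yarn_log_mul, bool is_deepseek2) {
    // YaRN "mscale" as in the Transformers reference implementation
    auto get_mscale = [](float scale, float mscale) {
        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
    };

    const float factor = 1.0f / rope_freq_scale; // context-extension factor

    float attn_factor;
    if (rope_yarn_log_mul != 0.0f) {
        float mscale = 1.0f; // assumed 1.0f, matching the TODO in the diff
        const float mscale_all_dims = rope_yarn_log_mul;

        // DeepSeek-V2 special case: mscale follows mscale_all_dim
        if (is_deepseek2 && mscale_all_dims != 1.0f) {
            mscale = mscale_all_dims;
        }

        attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
    } else {
        attn_factor = get_mscale(factor, 1.0f);
    }

    // cancel the 1 + 0.1*ln(1/freq_scale) scaling that the rope op applies internally
    attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));

    return attn_factor;
}

int main() {
    // illustrative values only: a 4x extension (freq_scale = 0.25) with mscale_all_dim = 0.707
    printf("yarn_attn_factor = %.4f\n", compute_yarn_attn_factor(0.25f, 0.707f, /*is_deepseek2=*/true));
    return 0;
}
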
src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion

@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     freq_base (cparams.rope_freq_base),
     freq_scale (cparams.rope_freq_scale),
     ext_factor (cparams.yarn_ext_factor),
-    attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
+    attn_factor (cparams.yarn_attn_factor),
     beta_fast (cparams.yarn_beta_fast),
     beta_slow (cparams.yarn_beta_slow),
     norm_eps (hparams.f_norm_eps),

src/llama-hparams.cpp

Lines changed: 0 additions & 11 deletions

@@ -3,7 +3,6 @@
 #include "ggml.h"

 #include <cassert>
-#include <cmath>

 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -231,13 +230,3 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama

     return false;
 }
-
-float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
-    GGML_ASSERT(ext_factor >= 0.0f);
-
-    if (ext_factor != 0.0f) {
-        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
-    }
-
-    return attn_factor;
-}

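The helper removed here used to re-apply the 1 / (1 + 0.1*ln(1/freq_scale)) correction on every graph build and KV-cache shift; after this commit the same correction is folded into cparams.yarn_attn_factor once at context creation (see the llama-context.cpp hunk above), so the call sites in llama-graph.cpp and llama-kv-cache.cpp can read the value directly. A minimal sketch of that equivalence, assuming ext_factor != 0 and plain floats in place of the real structs (the GGML_ASSERT from the removed code is omitted):

#include <cmath>
#include <cstdio>

// Sketch of the removed helper (formerly llama_hparams::yarn_attn_factor_adjust):
// scale attn_factor by 1 / (1 + 0.1*ln(1/freq_scale)) whenever YaRN is active.
static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
    if (ext_factor != 0.0f) {
        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
    }
    return attn_factor;
}

int main() {
    const float freq_scale  = 0.25f; // illustrative YaRN frequency scale
    const float ext_factor  = 1.0f;  // YaRN enabled
    const float base_factor = 1.0f;  // yarn_attn_factor before adjustment

    // old path: adjusted at every graph build / KV-cache shift
    const float old_path = yarn_attn_factor_adjust(base_factor, freq_scale, ext_factor);

    // new path: the same factor is baked into cparams.yarn_attn_factor once
    // at context creation and then used as-is
    const float new_path = base_factor * (1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)));

    printf("old = %.6f, new = %.6f\n", old_path, new_path); // identical by construction
    return 0;
}
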
src/llama-hparams.h

Lines changed: 0 additions & 7 deletions

@@ -268,13 +268,6 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
-
-    // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
-    // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
-    //
-    // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
-    //      https://github.com/ggml-org/llama.cpp/pull/17945
-    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

src/llama-kv-cache.cpp

Lines changed: 1 addition & 1 deletion

@@ -1372,7 +1372,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_ext_factor = cparams.yarn_ext_factor;
     const auto & yarn_beta_fast = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;
-    const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
+    const auto & yarn_attn_factor = cparams.yarn_attn_factor;

     const auto & n_rot = hparams.n_rot;
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE

src/llama-model.cpp

Lines changed: 0 additions & 26 deletions

@@ -2294,32 +2294,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         default: throw std::runtime_error("unsupported model architecture");
     }

-    // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
-    if (hparams.rope_yarn_log_mul != 0.0f) {
-        const float factor = 1.0f / hparams.rope_freq_scale_train;
-
-        // note: here we assume `mscale == 1.0f`
-        // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
-        float mscale = 1.0f;
-        const float mscale_all_dims = hparams.rope_yarn_log_mul;
-
-        // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
-        // special-case DEEPSEEK v2:
-        // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
-        if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
-            mscale = mscale_all_dims;
-        }
-
-        static auto get_mscale = [](float scale, float mscale) {
-            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
-        };
-
-        hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-
-        LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
-                __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
-    }
-
     pimpl->n_bytes = ml.n_bytes;

     pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();

0 commit comments
