
Commit 0759b09

graph: add f_attn_temp_offset (#18025)
1 parent 254098a

File tree

4 files changed: +11 -4 lines changed

src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.h
src/llama-model.cpp

src/llama-graph.cpp

Lines changed: 2 additions & 2 deletions

@@ -78,7 +78,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
         for (int i = 0; i < n_tokens; ++i) {
             const float pos = ubatch->pos[i];
             attn_scale_data[i] = std::log(
-                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
             ) * f_attn_temp_scale + 1.0;
         }

@@ -1203,7 +1203,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }

 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

     auto & cur = inp->attn_scale;
src/llama-graph.h

Lines changed: 3 additions & 2 deletions

@@ -132,8 +132,8 @@ class llm_graph_input_pos : public llm_graph_input_i {
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
     virtual ~llm_graph_input_attn_temp() = default;

     void set_input(const llama_ubatch * ubatch) override;
@@ -142,6 +142,7 @@ class llm_graph_input_attn_temp : public llm_graph_input_i {

     const uint32_t n_attn_temp_floor_scale;
     const float f_attn_temp_scale;
+    const float f_attn_temp_offset;
 };

 class llm_graph_input_pos_bucket : public llm_graph_input_i {

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions

@@ -165,6 +165,7 @@ struct llama_hparams {
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 0;
     float f_attn_temp_scale = 0.0f;
+    float f_attn_temp_offset = 0.0f; // offset position index

     // gemma3n altup
     uint32_t n_altup = 4; // altup_num_inputs

src/llama-model.cpp

Lines changed: 5 additions & 0 deletions

@@ -668,6 +668,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.n_swa = 8192;
                 hparams.n_attn_temp_floor_scale = 8192;
                 hparams.f_attn_temp_scale = 0.1f;
+                hparams.f_attn_temp_offset = 1.0f;
                 hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
             }

@@ -1646,6 +1647,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
             ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);

+            hparams.f_attn_temp_offset = 0.0f;
+
             switch (hparams.n_layer) {
                 case 27: type = LLM_TYPE_16B; break;
                 case 60: type = LLM_TYPE_236B; break;

@@ -2276,6 +2279,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
             ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);

+            hparams.f_attn_temp_offset = 0.0f;
+
             // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
             if (hparams.f_attn_temp_scale != 0.0f) {
                 hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
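Taken together, these hunks make the offset a per-architecture hyperparameter: llama4 sets f_attn_temp_offset = 1.0f, which reproduces the previous hard-coded (pos + 1.0f), while the two loaders that read the temperature keys from GGUF metadata pin it to 0.0f. The sketch below, again illustrative and not from the commit, shows how the choice of offset moves the position at which the scale first steps above 1.0.

// Sketch: boundary behavior of the two offsets used in this commit.
#include <cmath>
#include <cstdio>

int main() {
    const float floor_scale = 8192.0f;
    const float scale       = 0.1f;
    const float offsets[]   = {1.0f, 0.0f};  // llama4 vs. GGUF-keyed loaders
    const float positions[] = {8190.0f, 8191.0f, 8192.0f};

    for (float offset : offsets) {
        for (float pos : positions) {
            const float s = std::log(std::floor((pos + offset) / floor_scale) + 1.0f)
                            * scale + 1.0f;
            // offset = 1.0: first step at pos = 8191; offset = 0.0: at pos = 8192.
            std::printf("offset = %.0f pos = %.0f -> scale = %f\n", offset, pos, s);
        }
    }
    return 0;
}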
