Skip to content

Commit 4cf8a56

Browse files
Merge pull request #1 from nlasky2000-dot/fix-mistral3-attn-temp-scaling
fix: correct attention temperature scaling formula for Mistral3/Devstral
2 parents c00ff92 + 63bdc05 commit 4cf8a56

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/llama-graph.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -77,9 +77,11 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
7777
std::vector<float> attn_scale_data(n_tokens, 0.0f);
7878
for (int i = 0; i < n_tokens; ++i) {
7979
const float pos = ubatch->pos[i];
80-
attn_scale_data[i] = std::log(
81-
std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
82-
) * f_attn_temp_scale + 1.0;
80+
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/ministral3/modeling_ministral3.py#L101
81+
// scaling = 1 + beta * log(1 + floor(pos / max_position_embeddings))
82+
attn_scale_data[i] = 1.0f + f_attn_temp_scale * std::log(
83+
1.0f + std::floor(pos / n_attn_temp_floor_scale)
84+
);
8385
}
8486

8587
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));

0 commit comments

Comments (0)