4 changes: 2 additions & 2 deletions convert_hf_to_gguf.py

@@ -3788,7 +3788,7 @@ def set_gguf_parameters(self):
 self.gguf_writer.add_block_count(block_count)
 self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
 self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
-self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))

 # Mamba parameters
 self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
@@ -3799,7 +3799,7 @@ def set_gguf_parameters(self):
 self.gguf_writer.add_ssm_group_count(0)

 # MLP feed forward parameters (for attention layers)
-self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 13312))
 self.gguf_writer.add_file_type(self.ftype)

 def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
19 changes: 9 additions & 10 deletions src/llama-model.cpp

@@ -15988,7 +15988,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 {
 // PLaMo-2 uses combined QKV tensor
 ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
-cb(qkv, "qkv", il);
+cb(qkv, "wqkv", il);

 // split QKV tensor into Q, K, V
 const int64_t n_embd_head_q = hparams.n_embd_head_k;
@@ -16012,7 +16012,6 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {

 Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
 cb(Qcur, "Qcur_normed", il);
-
 Qcur = ggml_rope_ext(
 ctx0, Qcur, inp_pos, nullptr,
 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -16021,17 +16020,16 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {

 Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
 cb(Kcur, "Kcur_normed", il);
-
 Kcur = ggml_rope_ext(
 ctx0, Kcur, inp_pos, nullptr,
 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
 ext_factor, attn_factor, beta_fast, beta_slow
 );

-cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
+cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
 }

-cb(cur, "attn_out", il);
+cb(cur, "attn_output", il);

 return cur;
 }
@@ -16103,8 +16101,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 ggml_build_forward_expand(gf,
 ggml_cpy(ctx0, last_conv,
 ggml_view_1d(ctx0, conv_states_all,
-(d_conv - 1)*(d_inner)*(n_seqs),
-kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+cb(conv_states_all, "mamba_conv1d_state", il);

 // 1D convolution
 x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
@@ -16167,9 +16166,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 // store last states
 ggml_build_forward_expand(gf,
 ggml_cpy(ctx0,
-ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
-ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
-kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
+ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
Collaborator:

Mind commenting on these changes?

Contributor Author:

These changes replace d_inner with n_heads*head_dim for both y_ssm and ssm_states_all. I thought this was more natural because the ssm state created in get_ssm_rows() has the shape (d_state, head_dim, n_heads, n_seqs) here:

ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());

But if d_inner is always the same as n_heads*head_dim, I'm happy to revert this change since it's unnecessary.
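
As a quick illustration of the point above: assuming d_inner is defined as n_heads*head_dim (as the reply suggests), the old and new size expressions describe the same number of elements, so the change only affects how the size is spelled. This is a minimal standalone sketch, not part of the PR, and the dimension values are placeholders rather than PLaMo-2's actual hyperparameters.

// sketch: equivalence of d_state*d_inner*n_seqs and n_heads*head_dim*d_state*n_seqs
// when d_inner == n_heads*head_dim (assumed, placeholder dimensions)
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t d_state  = 64;   // SSM state size per head (placeholder)
    const int64_t head_dim = 64;   // Mamba head dimension (placeholder)
    const int64_t n_heads  = 32;   // number of Mamba heads (placeholder)
    const int64_t n_seqs   = 4;    // sequences in the ubatch (placeholder)

    // assumption under discussion: the inner width is heads times head size
    const int64_t d_inner = n_heads * head_dim;

    // element count of the recurrent-state slice written back per ubatch
    const int64_t n_elems_old = d_state * d_inner * n_seqs;             // old spelling
    const int64_t n_elems_new = n_heads * head_dim * d_state * n_seqs;  // new spelling
    assert(n_elems_old == n_elems_new);

    printf("state slice elements: %lld\n", (long long) n_elems_new);
    return 0;
}

If the two quantities could ever diverge, the explicit n_heads*head_dim spelling matches the (d_state, head_dim, n_heads, n_seqs) reshape quoted above more directly; otherwise the two forms are interchangeable.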

+cb(ssm_states_all, "mamba_ssm_states", il);

 ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
 cb(y, "mamba_y_view", il);