Merged

17 changes: 11 additions & 6 deletions convert_hf_to_gguf.py
@@ -4185,7 +4185,8 @@ def set_gguf_parameters(self):
# This logic matches modeling_plamo.py's is_mamba function
mamba_step = hparams.get("mamba_step", 2)
mamba_enabled = hparams.get("mamba_enabled", True)
- mamba_layers = []
+ num_key_value_heads = []
+ num_attention_heads = []

if mamba_enabled:
for i in range(block_count):
@@ -4195,17 +4196,21 @@ def set_gguf_parameters(self):
else:
is_mamba = (i % mamba_step) != (mamba_step // 2)
if is_mamba:
- mamba_layers.append(0)
+ num_key_value_heads.append(0)
+ num_attention_heads.append(0)
else:
- mamba_layers.append(hparams.get("num_key_value_heads", 4))
+ num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
+ num_attention_heads.append(hparams.get("num_attention_heads", 32))

- if mamba_layers:
-     self.gguf_writer.add_head_count_kv(mamba_layers)
+ if num_key_value_heads and num_attention_heads:
+     self.gguf_writer.add_head_count_kv(num_key_value_heads)
+     self.gguf_writer.add_head_count(num_attention_heads)

self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+ self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
+ self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_block_count(block_count)
- self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))

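For reference, a small self-contained sketch of the per-layer head-count logic added above, using a made-up hparams dict (the converter also has an extra special case for the first layers, collapsed out of this diff, which the sketch omits):

```python
# Standalone sketch of the per-layer head-count arrays built in the diff above.
# The hparams values below are invented for illustration, not a real config.
hparams = {"mamba_step": 2, "num_key_value_heads": 4, "num_attention_heads": 32}
block_count = 4

mamba_step = hparams.get("mamba_step", 2)
num_key_value_heads = []
num_attention_heads = []

for i in range(block_count):
    # Matches modeling_plamo.py's is_mamba(): with mamba_step = 2, even layers are Mamba.
    is_mamba = (i % mamba_step) != (mamba_step // 2)
    if is_mamba:
        # Mamba layers carry no attention heads, so both counts are zero.
        num_key_value_heads.append(0)
        num_attention_heads.append(0)
    else:
        num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
        num_attention_heads.append(hparams.get("num_attention_heads", 32))

print(num_key_value_heads)  # [0, 4, 0, 4]
print(num_attention_heads)  # [0, 32, 0, 32]
```

Writing both arrays keeps add_head_count and add_head_count_kv consistent per layer, which the per-layer hparams.n_head(i) / hparams.n_head_kv(i) lookups in llama-model.cpp below rely on.
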
2 changes: 1 addition & 1 deletion src/llama-hparams.h
@@ -42,7 +42,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
- int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
+ int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
28 changes: 17 additions & 11 deletions src/llama-model.cpp
@@ -1077,6 +1077,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
break;
default: type = LLM_TYPE_UNKNOWN;
}

+ // Load attention parameters
+ ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+ ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
Collaborator:
Suggested change:
- // Load attention parameters
- ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
- ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

Already done here:

// non-transformer models do not have attention heads
if (hparams.n_head() > 0) {
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
// gpt-j n_rot = rotary_dim
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
// sanity check for n_rot (optional)
hparams.n_rot = hparams.n_embd_head_k;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
}
}
} else {
hparams.n_rot = 0;
hparams.n_embd_head_k = 0;
hparams.n_embd_head_v = 0;
}

Contributor Author:
@CISC Ah, I applied your suggestion but noticed that the problem comes from the following two lines:

hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();

hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();

The main purpose of this PR is to support some variations of the PLaMo2 model created by pruning larger models, which have n_embd_head_k larger than n_embd / n_head.

So let me roll back these changes to support such cases in those PLaMo2 variants.

Collaborator:
> The main purpose of this PR is to support some variations of the PLaMo2 model created by pruning larger models, which have n_embd_head_k larger than n_embd / n_head.
>
> So let me roll back these changes to support such cases in those PLaMo2 variants.

But they would have to have the metadata present to work then, so no issue.

Collaborator:
@mitmul gentle ping

Contributor Author:
@CISC Sorry for my late reaction. Could you explain what you meant by this?

> But they would have to have the metadata present to work then, so no issue.

For example, pfnet/plamo-2.1-2b-cpt has

  • hidden_size = 2048 (for n_embd)
  • hidden_size_per_head = 128 (for n_embd_head_k and n_embd_head_v)
  • num_key_value_heads = 32 (for n_head_arr in the attention layers, i.e. every other layer, not the Mamba layers),

so that n_embd_head_k/v != n_embd / n_head.
Then, with the current load_hparams() it goes through the else block here:

} else {
hparams.n_rot = 0;
hparams.n_embd_head_k = 0;
hparams.n_embd_head_v = 0;
}

because hparams.n_head() returns 0, which comes from the first element of n_head_arr (layer 0 is a Mamba layer).

So I thought we need to call the following functions to set hparams.n_embd_head_k/v to the right values, which come from hidden_size_per_head in config.json:

llama.cpp/src/llama-model.cpp, lines 1082 to 1083 at 1be2787:

ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

That's why I sent this PR, but if I have misunderstood something about model loading, please let me know; I'd appreciate any advice on how to load variants like pfnet/plamo-2.1-2b-cpt.

Collaborator:
Aha, that makes sense then, thank you for the explanation. :)
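
To make the arithmetic from this thread concrete, here is a minimal sketch using the pfnet/plamo-2.1-2b-cpt numbers quoted above; the derived value is simply what the generic n_embd / n_head fallback would produce, not output from llama.cpp:

```python
# Figures quoted in the review thread for pfnet/plamo-2.1-2b-cpt.
hidden_size = 2048          # -> n_embd
hidden_size_per_head = 128  # -> n_embd_head_k / n_embd_head_v
num_key_value_heads = 32    # head count on the attention (non-Mamba) layers

# What the generic fallback n_embd / n_head() would compute on an attention layer:
derived_head_dim = hidden_size // num_key_value_heads
print(derived_head_dim)      # 64  -- wrong for this pruned variant
print(hidden_size_per_head)  # 128 -- must come from the GGUF key/value-length metadata

# On layer 0 (a Mamba layer) n_head() is 0, so the fallback cannot be computed at all,
# which is why load_hparams() otherwise ends up in the "else" branch quoted above.
```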

} break;
case LLM_ARCH_GPT2:
{
@@ -3367,17 +3371,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_PLAMO2:
{
- const uint32_t d_conv = hparams.ssm_d_conv;
- const uint32_t d_state = hparams.ssm_d_state;
- const uint32_t num_heads = hparams.ssm_dt_rank;
- const uint32_t intermediate_size = hparams.ssm_d_inner;
- const uint32_t head_dim = intermediate_size / num_heads;
- const uint32_t qk_dim = head_dim;
- const uint32_t v_dim = head_dim;
- const int64_t num_attention_heads = hparams.n_head();
- const int64_t q_num_heads = num_attention_heads;
- const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));

tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
@@ -3392,9 +3385,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
auto & layer = layers[i];
bool is_mamba_layer = hparams.is_recurrent(i);


layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

if (is_mamba_layer) {
+ const uint32_t d_conv = hparams.ssm_d_conv;
+ const uint32_t d_state = hparams.ssm_d_state;
+ const uint32_t num_heads = hparams.ssm_dt_rank;
+ const uint32_t intermediate_size = hparams.ssm_d_inner;
+ const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));

layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);

@@ -3411,6 +3411,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
} else {
+ const uint32_t head_dim = hparams.n_embd_head_k;
+ const uint32_t qk_dim = head_dim;
+ const uint32_t v_dim = head_dim;
+ const int64_t num_attention_heads = hparams.n_head(i);
+ const int64_t q_num_heads = num_attention_heads;
const int64_t num_key_value_heads = hparams.n_head_kv(i);
const int64_t k_num_heads = num_key_value_heads;
const int64_t v_num_heads = num_key_value_heads;
@@ -17520,6 +17525,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
const int64_t n_embd_head_q = hparams.n_embd_head_k;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ int32_t n_head = hparams.n_head(il);
int32_t n_head_kv = hparams.n_head_kv(il);

const int64_t q_offset = 0;
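
Finally, a hypothetical sketch of why the graph builder needs per-layer n_head and n_head_kv: the fused QKV projection of an attention layer is sliced at offsets derived from those counts. Only q_offset = 0 appears in the visible part of the diff; the remaining offset formulas and the num_attention_heads = 32 figure are assumptions for illustration, not the actual llama.cpp code.

```python
# Hypothetical offset arithmetic for slicing a fused QKV projection in a
# PLaMo2 attention layer (assumed layout: Q block, then K, then V).
n_embd_head_k = 128  # hidden_size_per_head
n_embd_head_v = 128  # hidden_size_per_head
n_head = 32          # per-layer attention heads (0 on Mamba layers)
n_head_kv = 32       # per-layer KV heads (0 on Mamba layers)

q_offset = 0                                     # shown in the diff
k_offset = q_offset + n_embd_head_k * n_head     # assumed: K follows Q
v_offset = k_offset + n_embd_head_k * n_head_kv  # assumed: V follows K
qkv_size = v_offset + n_embd_head_v * n_head_kv

print(q_offset, k_offset, v_offset, qkv_size)    # 0 4096 8192 12288

# On a Mamba layer both head counts are 0, so every slice is empty; the
# per-layer arrays written by convert_hf_to_gguf.py make that explicit.
```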