Commit 6d86944

Working through the previous attempt: implemented a more accurate conversion per the previous attempt, and added local sliding window attention that alternates every third layer.
1 parent ca353d3 commit 6d86944

File tree

4 files changed: +101, -125 lines (convert_hf_to_gguf.py, src/llama-hparams.h, src/llama-kv-cache-unified.cpp, src/llama-model.cpp)

convert_hf_to_gguf.py

Lines changed: 20 additions & 25 deletions
@@ -8308,37 +8308,32 @@ def prepare_tensors(self):
         raise ValueError(f"Unprocessed experts: {experts}")


-@ModelBase.register("ModernBertModel")
-class ModernBertModel(TextModel):
+@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
+class ModernBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.MODERN_BERT

-    def set_gguf_parameters(self) -> None:
-        # Determine block count (number of hidden layers)
-        block_count = self.hparams.get("num_hidden_layers") or self.hparams.get("num_hidden_layers_alt")
-        if block_count is None:
-            raise ValueError("Could not determine number of hidden layers from hparams")
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)

-        # Attention heads and dimensions
-        n_head = self.hparams.get("num_attention_heads")
-        if n_head is None:
-            raise ValueError("Missing 'num_attention_heads' in hparams")
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"])
+        self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"])
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

-        hidden_size = self.hparams["hidden_size"]
-        head_dim = hidden_size // n_head
-        ffn_dim = self.hparams.get("intermediate_size", 4 * hidden_size)
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # These layers act as MLM head, so we don't need them
+        if name.startswith("decoder."):
+            return []

-        # GGUF parameter assignment
-        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 512))
-        self.gguf_writer.add_embedding_length(hidden_size)
-        self.gguf_writer.add_feed_forward_length(ffn_dim)
-        self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(n_head)
-        self.gguf_writer.add_layer_norm_eps(self.hparams.get("layer_norm_eps", 1e-12))
-        self.gguf_writer.add_file_type(self.ftype)
+        if name.startswith("model."):
+            name = name[6:]

-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # Directly map tensor names without QKV splitting or reordering
-        return [(self.map_tensor_name(name), data_torch)]
+        return super().modify_tensors(data_torch, name, bid)


 ###### CONVERSION LOGIC ######
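
For reference, a minimal sketch (not part of the commit) of the HF-config-to-GGUF metadata mapping that the new set_gguf_parameters performs. The field names come from the diff above; the example values are placeholders rather than values read from any particular checkpoint.

    # Illustrative only: ModernBERT config.json fields consumed above, paired with the
    # gguf_writer call that records each one. The values here are placeholders.
    modernbert_config = {
        "local_attention":   128,       # -> gguf_writer.add_sliding_window(...)
        "global_rope_theta": 160000.0,  # -> gguf_writer.add_rope_freq_base(...)
        "local_rope_theta":  10000.0,   # -> gguf_writer.add_rope_freq_base_swa(...)
        "vocab_size":        50368,     # -> gguf_writer.add_vocab_size(...)
    }

    for field, value in modernbert_config.items():
        print(f"{field:17s} -> {value}")

Tokenizer handling and the remaining metadata are inherited from BertModel via super().set_gguf_parameters().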

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ enum llama_swa_type {
     LLAMA_SWA_TYPE_NONE     = 0,
     LLAMA_SWA_TYPE_STANDARD = 1,
     LLAMA_SWA_TYPE_CHUNKED  = 2,
+    LLAMA_SWA_TYPE_LOCAL    = 3,
 };

 struct llama_hparams_posnet {

src/llama-kv-cache-unified.cpp

Lines changed: 12 additions & 0 deletions
@@ -1807,6 +1807,18 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
                     return true;
                 }
             } break;
+        case LLAMA_SWA_TYPE_LOCAL:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff = p1 - p0;
+
+                // mask if outside the window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
+
+
     }

     return false;
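
A minimal sketch of the new masking rule (an illustration, not the llama.cpp implementation): with LLAMA_SWA_TYPE_LOCAL, a key at position p0 is masked for a query at position p1 whenever it falls outside a window of n_swa positions centered on the query, so the window extends in both directions rather than only backwards as with LLAMA_SWA_TYPE_STANDARD.

    # Sketch of the symmetric local window used by LLAMA_SWA_TYPE_LOCAL above:
    # mask when |p1 - p0| > n_swa / 2, i.e. the window is centered on the query.
    def is_masked_local(p0: int, p1: int, n_swa: int) -> bool:
        half_n_swa = n_swa // 2
        return abs(p1 - p0) > half_n_swa

    # example: with n_swa = 128, a query at position 200 sees keys 136..264
    visible = [p0 for p0 in range(512) if not is_masked_local(p0, 200, 128)]
    print(visible[0], visible[-1])  # 136 264

A two-sided window suits a bidirectional encoder such as ModernBERT, which attends to both past and future tokens within the local window.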

src/llama-model.cpp

Lines changed: 68 additions & 100 deletions
@@ -759,11 +759,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_MODERN_BERT:
             {
-                //ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
-                LLAMA_LOG_INFO("Switching Modern Bert Arch\n");
+
+                hparams.swa_type = LLAMA_SWA_TYPE_LOCAL;
+
+                hparams.set_swa_pattern(3, 0);
+                hparams.n_swa = 128;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
+
                 switch (hparams.n_layer) {
                     case 12:
-                        type = LLM_TYPE_47M; break; // granite-embeddings-mall
+                        type = LLM_TYPE_47M; break; // granite-embeddings-small
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
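
For intuition, a hypothetical re-implementation of the layer pattern assumed by set_swa_pattern(3, 0) above (an assumption about the helper's convention, not a copy of its code): with a period of 3, two of every three layers use the local sliding window and the remaining layer keeps full global attention, mirroring ModernBERT's global_attn_every_n_layers = 3.

    # Hypothetical sketch: True means the layer uses the sliding window, False means
    # full/global attention. Which slot in each group of three is the global one
    # depends on set_swa_pattern's second (dense_first) argument.
    def swa_layers(n_layer: int, n_pattern: int = 3, dense_first: bool = False) -> list[bool]:
        if n_pattern == 0:
            return [False] * n_layer
        if dense_first:
            return [il % n_pattern != 0 for il in range(n_layer)]           # global at 0, 3, 6, ...
        return [il % n_pattern < (n_pattern - 1) for il in range(n_layer)]  # global at 2, 5, 8, ...

    print(swa_layers(12))  # the 12-layer (LLM_TYPE_47M) case above

The n_swa = 128 assignment is a default that the GGUF metadata written by the converter (local_attention, via LLM_KV_ATTENTION_SLIDING_WINDOW) then overrides.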
@@ -7544,152 +7553,111 @@ struct llm_build_bert : public llm_graph_context {
 struct llm_build_modern_bert : public llm_graph_context {
     llm_build_modern_bert(const llama_model & model, const llm_graph_params & params)
         : llm_graph_context(params) {
-        const int64_t n_embd = hparams.n_embd;
-        const int64_t n_layer = hparams.n_layer;
-        const int64_t n_head = hparams.n_head();
-        const int64_t n_head_kv = hparams.n_head_kv();
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
-        const int64_t n_tokens = ubatch.n_tokens;
+        const int64_t n_embd      = hparams.n_embd;
+        const int64_t n_layer     = hparams.n_layer;
+        const int64_t n_head      = hparams.n_head();
+        const int64_t n_head_kv   = hparams.n_head_kv();
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        const int64_t n_tokens    = ubatch.n_tokens;

         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

-        // RoPE params
-        const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // uses rotary
-        const int32_t n_rot = hparams.n_rot;
-        const int32_t n_ctx_orig = hparams.n_ctx_train;
-
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-        ggml_tensor * inp_pos = nullptr;
-
-        // needs positions for RoPE
-        inp_pos = build_inp_pos();
+        // rope params
+        const int32_t rope_type   = LLAMA_ROPE_TYPE_NEOX;
+        const int32_t n_rot       = hparams.n_rot;
+        const int32_t n_ctx_orig  = hparams.n_ctx_train;
+        const float   freq_base   = hparams.rope_freq_base_train;
+        const float   freq_scale  = hparams.rope_freq_scale_train;
+        const float   attn_factor = 1.0f;
+        const float   ext_factor  = 1.0f;
+        const float   beta_fast   = 0.0f;
+        const float   beta_slow   = 0.0f;

-        // embeddings (token + optional type), NO absolute pos embed
-        inpL = build_inp_embd(model.tok_embd);
+        ggml_tensor * inp_pos = build_inp_pos();
+        ggml_tensor * inpL    = build_inp_embd(model.tok_embd);

         if (model.type_embd) {
-            ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
-            inpL = ggml_add(ctx0, inpL, type_row0);
+            inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0));
         }
-        cb(inpL, "inp_embd", -1);
-
-        // embeddings LayerNorm (embeddings.norm)
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
-        cb(inpL, "inp_norm", -1);

-        auto * inp_attn = build_attn_inp_no_cache();
+        auto * inp_attn = build_attn_inp_no_cache();
         ggml_tensor * inp_out_ids = build_inp_out_ids();

         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * x = inpL;

-            // pre attention norm (attn_norm). Layer 0 may be Identity() -> nullptr
+            // Pre attention Layer norm
             ggml_tensor * x_attn_in = x;
             if (model.layers[il].attn_norm) {
-                x_attn_in = build_norm(x,
-                                       model.layers[il].attn_norm,
-                                       model.layers[il].attn_norm_b,
-                                       LLM_NORM, il);
-                cb(x_attn_in, "attn_pre_norm", il);
-            } else {
-                cb(x_attn_in, "attn_pre_norm_identity", il);
+                x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
             }

-            // Attention: fused Wqkv -> split -> heads -> RoPE(Q,K) -> attn -> Wo
-            ggml_tensor * qkv = nullptr;
-            ggml_tensor * Qcur;
-            ggml_tensor * Kcur;
-            ggml_tensor * Vcur;
-
-            GGML_ASSERT(model.layers[il].wqkv); // fused QKV
-            qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
-            cb(qkv, "wqkv", il);
-
+            // fused qkv
+            GGML_ASSERT(model.layers[il].wqkv);
+            ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
             if (model.layers[il].bqkv) {
                 qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
-                cb(qkv, "bqkv", il);
             }

-            // Fused layout: [ (n_embd + 2*n_embd_gqa), n_tokens ]
-            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0*sizeof(float)*(n_embd)));
-            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd)));
-            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+            ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0));
+            ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd));
+            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd + n_embd_gqa));

-            // optional per Q/K
-            if (model.layers[il].attn_q_norm) {
-                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
-            }
-            if (model.layers[il].attn_k_norm) {
-                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
-            }
+            // optional q/k LayerNorm
+            if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
+            if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);

-            // heads
+            // reshape for multi head
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-            // RoPE (NEOX ... maybe?) on Q and K
+            // rope embedding
             Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
-                                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                                 ext_factor, attn_factor, beta_fast, beta_slow);
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
             Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
-                                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                                 ext_factor, attn_factor, beta_fast, beta_slow);
-
-            cb(Qcur, "Qcur_rope", il);
-            cb(Kcur, "Kcur_rope", il);
-            cb(Vcur, "Vcur", il);
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);

             ggml_tensor * attn_out = build_attn(
                 inp_attn,
-                model.layers[il].wo, model.layers[il].bo, // Wo, optional bias
+                model.layers[il].wo, model.layers[il].bo,
                 Qcur, Kcur, Vcur,
-                /*K_cache*/ nullptr,
-                /*V_cache*/ nullptr,
+                /*k cache*/ nullptr,
+                /*v cache*/ nullptr,
                 1.0f / sqrtf(float(n_embd_head)),
-                il);
-            cb(attn_out, "attn_out", il);
+                il
+            );

-            // residual after attention
             ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);

-            // ifwe subselect outputs, do it at the last layer after attn resid
+            // optional subselect output tokens (inp_out_ids)
             if (il == n_layer - 1 && inp_out_ids) {
-                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
-                x = ggml_get_rows(ctx0, x, inp_out_ids);
+                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }

-            // pre mlp norm
-            ggml_tensor * h = build_norm(cur_attn,
-                                         model.layers[il].ffn_norm,
-                                         model.layers[il].ffn_norm_b,
-                                         LLM_NORM, il);
-            cb(h, "mlp_pre_norm", il);
+            // pre mlp LayerNorm
+            ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);

-            // GEGLU because we will split ffn_up which has shape [n_embd, n_ff * 2] and ffn_down has shape [n_ff, n_embd]
+            // geglu FFN
             ggml_tensor * mlp_out = build_ffn(
                 h,
-                model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL,
-                /*gate*/ NULL , /*gate_b*/ NULL, /*gate_shexp*/ NULL,
-                model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL,
-                /*act_scales*/ NULL,
+                model.layers[il].ffn_up, NULL, NULL,
+                NULL, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
                 LLM_FFN_GEGLU, LLM_FFN_PAR, il
             );

-            cb(mlp_out, "ffn_out_geglu", il);
-            // Residual after MLP
-            ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn);
-
-            // feed into next layer
-            inpL = cur_layer;
+            // resid addition
+            inpL = ggml_add(ctx0, mlp_out, cur_attn);
         }

-        // final model norm (final_norm)
-        cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
-        cb(cur, "final_norm", -1);
-
+        ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         res->t_embd = cur;
         ggml_build_forward_expand(gf, cur);
     }
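
To make the fused Wqkv handling in the rewritten builder easier to follow, here is a small numpy sketch (illustrative only; the sizes are example values, not taken from a real checkpoint) of the assumed per-token layout [q | k | v] and the per-head reshape that precedes RoPE.

    # Illustrative numpy sketch of the fused QKV split used above: each token's projection
    # is laid out as [q (n_embd) | k (n_embd_gqa) | v (n_embd_gqa)], then reshaped per head.
    import numpy as np

    n_embd, n_head, n_head_kv, n_tokens = 768, 12, 12, 4   # example sizes only
    n_embd_head = n_embd // n_head
    n_embd_gqa  = n_embd_head * n_head_kv

    qkv = np.random.rand(n_tokens, n_embd + 2 * n_embd_gqa).astype(np.float32)

    q = qkv[:, :n_embd]                         # offset 0
    k = qkv[:, n_embd:n_embd + n_embd_gqa]      # offset n_embd
    v = qkv[:, n_embd + n_embd_gqa:]            # offset n_embd + n_embd_gqa

    # mirror the ggml_reshape_3d calls: (tokens, heads, head_dim)
    q = q.reshape(n_tokens, n_head,    n_embd_head)
    k = k.reshape(n_tokens, n_head_kv, n_embd_head)
    v = v.reshape(n_tokens, n_head_kv, n_embd_head)
    print(q.shape, k.shape, v.shape)

Note that in ggml itself the final argument of ggml_view_2d is a byte offset into the tensor, so offsets there are measured in bytes rather than elements.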
