
Commit e101005

working on swa with local and global alternating attention
1 parent 39c0291 commit e101005
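
Context for the change: the model graph below alternates between full (global) attention layers and sliding-window (local) attention layers, with every third layer global and each kind using its own RoPE frequency base. The following standalone sketch is not part of the commit; it only replays the (il + 1) % 3 == 0 rule and the 160000/10000 RoPE bases that appear in the src/llama-model.cpp diff, with the layer count and window size as made-up example values.

// Standalone illustration (not from this commit): print the layer schedule
// implied by the diff, where every third layer is global and the rest use
// sliding-window (local) attention with a smaller RoPE base.
#include <cstdio>

int main() {
    const int n_layer     = 22;   // example value only, not from the commit
    const int n_local_swa = 128;  // example sliding-window size (hparams.n_swa)

    for (int il = 0; il < n_layer; ++il) {
        const bool  is_global   = ((il + 1) % 3 == 0);              // rule from the diff
        const float freq_base_l = is_global ? 160000.0f : 10000.0f; // rope theta per layer

        std::printf("layer %2d: %-6s rope_base=%8.0f window=%d\n",
                il,
                is_global ? "global" : "local",
                freq_base_l,
                is_global ? -1 : n_local_swa); // -1 = unrestricted (full attention)
    }
    return 0;
}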

File tree

src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp

3 files changed: +113, -19 lines

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_DIMENSION_COUNT,    "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
     { LLM_KV_ROPE_FREQ_BASE,          "%s.rope.freq_base" },
+    { LLM_KV_ROPE_FREQ_BASE_SWA,      "%s.rope.freq_base_swa" },
     { LLM_KV_ROPE_SCALE_LINEAR,       "%s.rope.scale_linear" },
     { LLM_KV_ROPE_SCALING_TYPE,       "%s.rope.scaling.type" },
     { LLM_KV_ROPE_SCALING_FACTOR,     "%s.rope.scaling.factor" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -176,6 +176,7 @@ enum llm_kv {
    LLM_KV_ROPE_DIMENSION_SECTIONS,
    LLM_KV_ROPE_FREQ_BASE,
    LLM_KV_ROPE_SCALE_LINEAR,
+   LLM_KV_ROPE_FREQ_BASE_SWA,
    LLM_KV_ROPE_SCALING_TYPE,
    LLM_KV_ROPE_SCALING_FACTOR,
    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
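
Together, the two llama-arch changes register a new GGUF key, "%s.rope.freq_base_swa", next to the existing "%s.rope.freq_base". The model diff below does not read the key yet (it hard-codes 160000.0f and 10000.0f), but the apparent intent is a separate RoPE base for the sliding-window layers. Below is a minimal sketch of that per-layer selection, using hypothetical hparams fields rope_freq_base and rope_freq_base_swa and the same every-third-layer rule; none of this is code from the commit.

// Hedged sketch (not from the commit): choose the RoPE frequency base per layer
// from two hyperparameters, one for global layers and one for SWA layers.
struct rope_hparams {
    float rope_freq_base     = 160000.0f; // "%s.rope.freq_base"     (global layers)
    float rope_freq_base_swa =  10000.0f; // "%s.rope.freq_base_swa" (sliding-window layers)
};

static float rope_freq_base_for_layer(const rope_hparams & hp, int il) {
    const bool is_global = ((il + 1) % 3 == 0); // alternation rule from the model diff
    return is_global ? hp.rope_freq_base : hp.rope_freq_base_swa;
}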

src/llama-model.cpp

Lines changed: 111 additions & 19 deletions
@@ -7559,6 +7559,7 @@ struct llm_build_modern_bert : public llm_graph_context {
         const int64_t n_head_kv    = hparams.n_head_kv();
         const int64_t n_embd_head  = hparams.n_embd_head_v;
         const int64_t n_embd_gqa   = hparams.n_embd_v_gqa();
+        const int64_t n_local_swa  = hparams.n_swa;
         const int64_t n_tokens     = ubatch.n_tokens;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7574,27 +7575,38 @@ struct llm_build_modern_bert : public llm_graph_context {
         const float beta_fast = 0.0f;
         const float beta_slow = 0.0f;
 
-        ggml_tensor * inp_pos = build_inp_pos();
+
+        // positions provided as a graph input; 4096 is a hard-coded upper bound here
+        ggml_tensor * inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1);
+        ggml_set_input(inp_pos_global);
+        size_t element_size = ggml_type_size(inp_pos_global->type);
+
+        size_t nb1 = element_size;
+        size_t nb2 = nb1;
+
+        inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0);
+        inp_pos_global = ggml_cont(ctx0, inp_pos_global);
+
         ggml_tensor * inpL = build_inp_embd(model.tok_embd);
 
         if (model.type_embd) {
             inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0));
         }
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * x = inpL;
 
-            // Pre attention Layer norm
+            // pre-attention LayerNorm
             ggml_tensor * x_attn_in = x;
             if (model.layers[il].attn_norm) {
                 x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
             }
 
-            // fused qkv
+            // fused QKV
             GGML_ASSERT(model.layers[il].wqkv);
             ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
             if (model.layers[il].bqkv) {
@@ -7609,41 +7621,120 @@ struct llm_build_modern_bert : public llm_graph_context {
             if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
             if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
 
-            // reshape for multi head
+            // reshape for multi-head attention
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            // rope embedding
-            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
-            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            // global or local (sliding-window) layer: every third layer is global
+            bool  is_global    = ((il + 1) % 3 == 0);
+            float freq_base_l  = is_global ? 160000.0f : 10000.0f; // rope theta
+            float freq_scale_l = 1.0f;
+
+            ggml_tensor * pos_q = inp_pos_global;
+
+            ggml_tensor * K_work = Kcur;
+            ggml_tensor * V_work = Vcur;
+            ggml_tensor * pos_k  = inp_pos_global;
+
+            if (!is_global) {
+                // gather the K/V rows selected by the SWA index tensor
+                ggml_tensor * idx_src = inp_attn->self_k_idxs_swa;
+
+                ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0);
+                ggml_tensor * idx_cont   = ggml_cont(ctx0, idx_view1d);
+
+                ggml_tensor * idx_i32 = idx_cont;
+                if (idx_i32->type != GGML_TYPE_I32) {
+                    idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32);
+                }
+
+                const int64_t n_indices = idx_i32->ne[0];
+                ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0);
+
+                idx_2d = ggml_cont(ctx0, idx_2d);
+                if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32);
+
+                // debug: log the shapes involved in the gather
+                LLAMA_LOG_INFO("swa gather: Kcur ne=[%lld %lld %lld], idx_2d ne=[%lld %lld %lld %lld], type=%d\n",
+                        Kcur->ne[0], Kcur->ne[1], Kcur->ne[2],
+                        idx_2d->ne[0], idx_2d->ne[1], idx_2d->ne[2], idx_2d->ne[3],
+                        idx_2d->type);
+
+                K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
+                V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
+
+                ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
+
+                if (!ggml_is_vector(pos_rows)) {
+                    const int64_t n_el = ggml_nelements(pos_rows);
+                    pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0);
+                    pos_rows = ggml_cont(ctx0, pos_rows);
+                } else {
+                    pos_rows = ggml_cont(ctx0, pos_rows);
+                }
+                // ensure I32
+                if (pos_rows->type != GGML_TYPE_I32) {
+                    pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32);
+                }
+
+                // final pos_k to pass to rope
+                pos_k = pos_rows;
+                LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, type=%d\n", pos_k->ne[0], pos_k->type);
+            }
+
+            if (!ggml_is_vector(pos_q)) {
+                const int64_t n_el = ggml_nelements(pos_q);
+                pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
+                pos_q = ggml_cont(ctx0, pos_q);
+            }
+
+            // apply rope
+            Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            if (!ggml_is_vector(pos_k)) {
+                const int64_t n_el = ggml_nelements(pos_k);
+                pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0);
+                pos_k = ggml_cont(ctx0, pos_k);
+            }
+
+            K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow);
 
+            // choose mask: global vs SWA
+            ggml_tensor * kq_b_layer = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;
+
             ggml_tensor * attn_out = build_attn(
                     inp_attn,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Qcur, Kcur, Vcur,
-                    /*k cache*/ nullptr,
-                    /*v cache*/ nullptr,
+                    model.layers[il].wo,
+                    model.layers[il].bo,
+                    Qcur,
+                    K_work,
+                    V_work,
+                    kq_b_layer,
+                    nullptr,
                     1.0f / sqrtf(float(n_embd_head)),
                     il
             );
 
+            // residual addition
             ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);
 
-            // optional subselect output tokens (inp_out_ids)
+            // optional output token selection (inp_out_ids)
             if (il == n_layer - 1 && inp_out_ids) {
                 cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
                 inpL     = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
-            // pre mlp LayerNorm
+            // pre-MLP LayerNorm
             ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
 
-            // geglu FFN
+            // GeGLU FFN
             ggml_tensor * mlp_out = build_ffn(
                     h,
                     model.layers[il].ffn_up, NULL, NULL,
@@ -7653,10 +7744,11 @@ struct llm_build_modern_bert : public llm_graph_context {
                     LLM_FFN_GEGLU, LLM_FFN_PAR, il
             );
 
-            // resid addition
+            // residual addition after FFN
             inpL = ggml_add(ctx0, mlp_out, cur_attn);
         }
 
+
         ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         res->t_embd = cur;
         ggml_build_forward_expand(gf, cur);
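
What makes the local layers local is the per-layer mask choice above (self_kq_mask vs self_kq_mask_swa) combined with the SWA K/V row gather. As a standalone illustration only, not llama.cpp's actual mask construction, a sliding-window mask for a bidirectional encoder keeps score (i, j) only when |i - j| falls inside the window, while the global mask keeps every pair; the symmetric window here is a simplifying assumption.

// Standalone illustration (not llama.cpp code): global vs sliding-window mask.
#include <cmath>
#include <cstdlib>
#include <vector>

// window <= 0 means global attention (no restriction); otherwise token i may
// only attend to tokens j with |i - j| < window. Disallowed pairs get -INF so
// they vanish after softmax.
static std::vector<float> build_mask(int n_tokens, int window) {
    std::vector<float> mask(size_t(n_tokens) * n_tokens, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_tokens; ++j) {
            const bool allowed = (window <= 0) || (std::abs(i - j) < window);
            mask[size_t(i) * n_tokens + j] = allowed ? 0.0f : -INFINITY;
        }
    }
    return mask; // added to the attention scores before softmax
}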
