@@ -1,37 +1,38 @@
-#include "../llama-model.h"
+#include "llm_build_lfm2.h"
+
 #include "../llama-graph.h"
+#include "../llama-model.h"
+#include "../llama-memory-hybrid.h"
 
-#include "llm_build_lfm2.h"
 #include <cmath>
 
-llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params),
+    model(model) {
     ggml_tensor * cur = build_inp_embd(model.tok_embd);
     cb(cur, "model.embed_tokens", -1);
 
     ggml_tensor * inp_pos = build_inp_pos();
-    auto * inp_hybrid = build_inp_mem_hybrid();
+    auto * inp_hybrid = build_inp_mem_hybrid();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
         auto * prev_cur = cur;
-        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
         cb(cur, "model.layers.{}.operator_norm", il);
 
-        // TODO: implement recurrent/attention logic inline
-        // cur = hparams.is_recurrent(il) ?
-        //     build_shortconv_block(cur, inp_hybrid->get_recr(), il) :
-        //     build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il);
+        cur = hparams.is_recurrent(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) :
+                                         build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il);
 
         if (il == n_layer - 1 && inp_out_ids) {
-            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             prev_cur = ggml_get_rows(ctx0, prev_cur, inp_out_ids);
         }
-        ;
+
         cur = ggml_add(ctx0, prev_cur, cur);
-        // TODO: implement feed_forward inline
-        // cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
+        cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
     }
-    ;
+
     cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
     cb(cur, "model.embedding_norm", -1);
     res->t_embd = cur;
@@ -43,4 +44,117 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
 
     ggml_build_forward_expand(gf, cur);
 }
-;
+
+ggml_tensor * llm_build_lfm2::build_feed_forward(ggml_tensor * cur, int il) const {
+    cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+    cb(cur, "model.layers.{}.ffn_norm", il);
+
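+    // gated FFN: down( SiLU(gate(x)) * up(x) ); the asserts confirm there are no FFN bias tensors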
+    GGML_ASSERT(!model.layers[il].ffn_up_b);
+    GGML_ASSERT(!model.layers[il].ffn_gate_b);
+    GGML_ASSERT(!model.layers[il].ffn_down_b);
+    cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+    cb(cur, "model.layers.{}.feed_forward.w2", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur,
+                                               ggml_tensor * inp_pos,
+                                               llm_graph_input_attn_kv * inp_attn,
+                                               int il) const {
+    GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
+    const auto n_embd_head = hparams.n_embd_head_v;
+    const auto n_head_kv = hparams.n_head_kv(il);
+
+    auto * q = build_lora_mm(model.layers[il].wq, cur);
+    cb(q, "model.layers.{}.self_attn.q_proj", il);
+    auto * k = build_lora_mm(model.layers[il].wk, cur);
+    cb(k, "model.layers.{}.self_attn.k_proj", il);
+    auto * v = build_lora_mm(model.layers[il].wv, cur);
+    cb(v, "model.layers.{}.self_attn.v_proj", il);
+
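+    // split into heads: {n_embd, n_tokens} -> {n_embd_head, n_head, n_tokens}; k/v use n_head_kv heads (GQA)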
+    q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
+    k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
+    v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
+
+    // qk norm
+    q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+    cb(q, "model.layers.{}.self_attn.q_layernorm", il);
+    k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+    cb(k, "model.layers.{}.self_attn.k_layernorm", il);
+
+    // RoPE
+    q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                      attn_factor, beta_fast, beta_slow);
+    k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                      attn_factor, beta_fast, beta_slow);
+
+    cur = build_attn(inp_attn, model.layers[il].wo, NULL, q, k, v, nullptr, nullptr, nullptr,
+                     1.0f / sqrtf(float(n_embd_head)), il);
+
+    cb(cur, "model.layers.{}.self_attn.out_proj", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+    const uint32_t kv_head = mctx_cur->get_head();
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+    const int64_t n_seqs = ubatch.n_seqs;
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
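+    // the recurrent state caches the last d_conv = n_shortconv_l_cache - 1 input columns for the depthwise conv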
+    GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+    const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+    auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
+    cb(bcx, "model.layers.{}.conv.in_proj", il);
+
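+    // in_proj packs three projections along dim 0: b (pre-conv gate), c (post-conv gate) and x (conv input)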
+    constexpr auto n_chunks = 3;
+    GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
+    const auto chunk_size = bcx->ne[0] / n_chunks;
+    auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                            0 * chunk_size * ggml_element_size(bcx));
+    auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                            1 * chunk_size * ggml_element_size(bcx));
+    auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                            2 * chunk_size * ggml_element_size(bcx));
+
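+    // gate x with b, then transpose to {n_seq_tokens, n_embd, n_seqs} so tokens run along dim 0 for ggml_ssm_conv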
+    auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
+
+    // read conv state
+    auto * conv_state = mctx_cur->get_r_l(il);
+    auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
+    auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
+
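+    // prepend the cached columns so the convolution can look back across ubatch boundaries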
+    bx = ggml_concat(ctx0, conv, bx, 0);
+    GGML_ASSERT(bx->ne[0] > conv->ne[0]);
+
+    // the last d_conv columns form the new conv state
+    auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
+                                   (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+    GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
+
+    // write the new conv state back into the recurrent memory
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
+                                           ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
+                                                        kv_head * d_conv * n_embd * ggml_element_size(new_conv))));
+
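+    // depthwise causal convolution over the gated, state-extended sequence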
+    auto * conv_kernel = model.layers[il].shortconv.conv;
+    auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+    cb(conv_out, "model.layers.{}.conv.conv", il);
+
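+    // apply the output gate c and project back to the embedding dimension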
+    auto * y = ggml_mul(ctx0, c, conv_out);
+    y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
+    cb(y, "model.layers.{}.conv.out_proj", il);
+    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+    y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
+
+    return y;
+}