#include "models.h"

llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // MuP scaling: embeddings * sqrt(hidden_size)
    // mup_enabled = true, hidden_size = 1024, scale = 32.0
    inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
    cb(inpL, "inp_embd_scaled", -1);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();
    auto * inp_attn = build_attn_inp_kv();
    ggml_tensor * inp_out_ids = build_inp_out_ids();
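    // note: inp_attn provides the KV-cache-backed attention inputs; inp_out_ids selects the
    //       token rows whose outputs are needed (gathered on the last layer below)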

    const float kq_scale = 1.0f/sqrtf(float(n_embd_head));

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // dual attention normalization (pre)
        cur = build_norm(inpL,
                model.layers[il].attn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // self-attention
        {
            ggml_tensor * attn_inp = cur; // save input for gate computation

            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            // compute gate from input
            ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
            cb(gate, "attn_gate_proj", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);

            // Q/K normalization
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);

            // RoPE only for sliding_attention layers (every 4th layer is full_attention)
            // layer_types[i] = "sliding_attention" if (i+1) % global_attn_every_n_layers != 0
            bool is_sliding = ((il + 1) % 4) != 0; // global_attn_every_n_layers = 4
            if (is_sliding) {
                Qcur = ggml_rope_ext(
                        ctx0, Qcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Qcur, "Qcur_rope", il);

                Kcur = ggml_rope_ext(
                        ctx0, Kcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Kcur, "Kcur_rope", il);
            }

            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            cur = build_attn(inp_attn,
                    NULL, NULL, // wo will be applied after gating
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
            cb(cur, "attn_out", il);

            // attention gating: attn_out * sigmoid(gate) BEFORE o_proj
            gate = ggml_sigmoid(ctx0, gate);
            cb(gate, "attn_gate_sig", il);
            cur = ggml_mul(ctx0, cur, gate);
            cb(cur, "attn_gated", il);

            // now apply output projection
            cur = build_lora_mm(model.layers[il].wo, cur);
            cb(cur, "attn_o_proj", il);
        }

        // dual attention normalization (post)
        cur = build_norm(cur,
                model.layers[il].attn_post_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

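        // on the last layer, keep only the rows for which outputs are requested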
        if (il == n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

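        // residual connection: add the post-normed attention output back to the layer input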
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        // dual ffn normalization (pre)
        cur = build_norm(ffn_inp,
                model.layers[il].ffn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // MoE or dense FFN
        if ((uint32_t) il >= hparams.n_layer_dense_lead) {
            // MoE layer with sigmoid routing, normalization, and scaling
            ggml_tensor * moe_out = build_moe_ffn(cur,
                    model.layers[il].ffn_gate_inp,
                    model.layers[il].ffn_up_exps,
                    model.layers[il].ffn_gate_exps,
                    model.layers[il].ffn_down_exps,
                    model.layers[il].ffn_exp_probs_b,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU,
                    hparams.expert_weights_norm != 0,     // norm_w (route_norm=True)
                    hparams.expert_weights_scale != 0.0f, // scale_w
                    hparams.expert_weights_scale,         // w_scale (route_scale=2.826)
                    (llama_expert_gating_func_type) hparams.expert_gating_func,
                    il);
            cb(moe_out, "ffn_moe_out", il);

            // shared expert
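            // (the shared expert runs on every token and its output is added to the routed expert output)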
            if (hparams.n_expert_shared > 0) {
                ggml_tensor * ffn_shexp = build_ffn(cur,
                        model.layers[il].ffn_up_shexp,   NULL, NULL,
                        model.layers[il].ffn_gate_shexp, NULL, NULL,
                        model.layers[il].ffn_down_shexp, NULL, NULL,
                        NULL,
                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(ffn_shexp, "ffn_shexp", il);

                cur = ggml_add(ctx0, moe_out, ffn_shexp);
                cb(cur, "ffn_out", il);
            } else {
                cur = moe_out;
            }
        } else {
            // dense layer
            cur = build_ffn(cur,
                    model.layers[il].ffn_up,   NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(cur, "ffn_out", il);
        }

        // dual ffn normalization (post)
        cur = build_norm(cur,
                model.layers[il].ffn_post_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        cur = ggml_add(ctx0, cur, ffn_inp);
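        // optionally apply a control vector to the layer output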
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    cur = build_norm(cur,
            model.output_norm, NULL,
            LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

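    // the final RMS-normed hidden states are exposed as the embedding output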
    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);
    res->t_logits = cur;

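    // add the logits tensor and all of its dependencies to the forward graph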
    ggml_build_forward_expand(gf, cur);
}