@@ -1,11 +1,13 @@
-#include "../llama-model.h"
+
 #include "../llama-graph.h"
-#include "llm_graph_context_mamba.h"
+#include "../llama-model.h"
 
+#include "llm_graph_context_mamba.h"
 #include "llm_build_mamba.h"
+
 #include <cmath>
 
-llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -18,22 +20,20 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para
 
     for (int il = 0; il < n_layer; ++il) {
         // norm
-        cur = build_norm(inpL,
-                model.layers[il].attn_norm, NULL,
-                LLM_NORM_RMS, il);
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
         cb(cur, "attn_norm", il);
 
         if (model.arch == LLM_ARCH_MAMBA2) {
-            // TODO: implement mamba2_layer inline
-            // cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
+            cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
         } else {
-            // TODO: implement mamba_layer inline
-            // cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
+            cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
         }
+
         if (il == n_layer - 1 && inp_out_ids) {
-            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
         }
+
         // residual
         cur = ggml_add(ctx0, cur, inpL);
 
@@ -43,7 +43,7 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para
         // input for next layer
         inpL = cur;
     }
-    ;
+
     // final rmsnorm
     cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
 
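This change only compiles because the two layer builders are inherited rather than defined locally: the constructor now chains to llm_graph_context_mamba instead of llm_graph_context, and the former TODOs become plain calls to build_mamba_layer() / build_mamba2_layer(). Below is a minimal sketch of the class relationship this relies on, assuming llm_graph_context_mamba.h declares an intermediate context between llm_graph_context and the per-model builders; the member signatures are inferred from the call sites in the diff, not copied from the header.

    // Sketch only: class and function names come from this diff; the
    // signatures are assumptions matching build_mamba_layer(rs_inp, cur,
    // model, ubatch, il) as called in the loop above.

    #include "../llama-graph.h"   // llm_graph_context, llm_graph_params, llm_graph_input_rs
    #include "../llama-model.h"   // llama_model, llama_ubatch

    // Intermediate context shared by Mamba-family builders. Hoisting the two
    // layer builders into one base lets plain Mamba, Mamba-2, and any hybrid
    // architecture reuse them instead of re-implementing them inline.
    struct llm_graph_context_mamba : public llm_graph_context {
        llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}

        ggml_tensor * build_mamba_layer (llm_graph_input_rs * inp, ggml_tensor * cur,
                                         const llama_model & model, const llama_ubatch & ubatch, int il);
        ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur,
                                         const llama_model & model, const llama_ubatch & ubatch, int il);
    };

    // Re-parenting the builder onto the intermediate context (rather than
    // llm_graph_context directly) is what makes the two calls resolve.
    struct llm_build_mamba : public llm_graph_context_mamba {
        llm_build_mamba(const llama_model & model, const llm_graph_params & params);
    };

The include reshuffle at the top of the diff lines up with the same split: llm_graph_context_mamba.h is now pulled in ahead of llm_build_mamba.h, presumably so the intermediate context is declared before the builder's own header names it as a base.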