Skip to content

Commit 64de434

Browse files
committed
Fixes from main branch
1 parent 7bedf4c commit 64de434

File tree

2 files changed

+13
-13
lines changed

2 files changed: +13 additions, -13 deletions

src/models/llm_build_mamba.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
#include "../llama-model.h"
1+
22
#include "../llama-graph.h"
3-
#include "llm_graph_context_mamba.h"
3+
#include "../llama-model.h"
44

5+
#include "llm_graph_context_mamba.h"
56
#include "llm_build_mamba.h"
7+
68
#include <cmath>
79

8-
llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
10+
llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
911
ggml_tensor * cur;
1012
ggml_tensor * inpL;
1113

@@ -18,22 +20,20 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para
1820

1921
for (int il = 0; il < n_layer; ++il) {
2022
// norm
21-
cur = build_norm(inpL,
22-
model.layers[il].attn_norm, NULL,
23-
LLM_NORM_RMS, il);
23+
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
2424
cb(cur, "attn_norm", il);
2525

2626
if (model.arch == LLM_ARCH_MAMBA2) {
27-
// TODO: implement mamba2_layer inline
28-
// cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
27+
cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
2928
} else {
30-
// TODO: implement mamba_layer inline
31-
// cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
29+
cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
3230
}
31+
3332
if (il == n_layer - 1 && inp_out_ids) {
34-
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
33+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
3534
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
3635
}
36+
3737
// residual
3838
cur = ggml_add(ctx0, cur, inpL);
3939

@@ -43,7 +43,7 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para
4343
// input for next layer
4444
inpL = cur;
4545
}
46-
;
46+
4747
// final rmsnorm
4848
cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
4949

src/models/llm_build_mamba.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55

66
#include <cmath>
77

8-
struct llm_build_mamba : public llm_graph_context {
8+
struct llm_build_mamba : public llm_graph_context_mamba {
99
llm_build_mamba(const llama_model & model, const llm_graph_params & params);
1010
};

0 commit comments

Comments (0)