src/llama-model.cpp (56 changes: 39 additions & 17 deletions)
@@ -16430,46 +16430,68 @@ struct llm_build_lfm2 : public llm_graph_context {
ggml_tensor * cur,
llm_graph_input_rs * inp_recr,
int il) {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+const uint32_t kv_head = mctx_cur->get_head();
+const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+const int64_t n_seqs = ubatch.n_seqs;
+GGML_ASSERT(n_seqs != 0);
+GGML_ASSERT(ubatch.equal_seqs);
+GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

+GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
cb(bcx, "model.layers.{}.conv.in_proj", il);

constexpr auto n_chunks = 3;
GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
auto const chunk_size = bcx->ne[0] / n_chunks;
-auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
-auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
-auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
+auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
+auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
+auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));

auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));

-// read conv state directly, with build_rs generation is slower
-ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-const int64_t n_seqs = ubatch.n_seqs;
-ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+// read conv state
+auto * conv_state = mctx_cur->get_r_l(il);
+auto * conv_rs = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);

bx = ggml_concat(ctx0, conv, bx, 0);
GGML_ASSERT(bx->ne[0] > conv->ne[0]);

-auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+// the last d_conv columns are the new conv state
+auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
GGML_ASSERT(ggml_are_same_shape(conv, new_conv));

-// write conv state
-ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+// write the new conv state
+ggml_build_forward_expand(
+    gf,
+    ggml_cpy(
+        ctx0,
+        new_conv,
+        ggml_view_1d(
+            ctx0,
+            conv_state,
+            ggml_nelements(new_conv),
+            kv_head*d_conv*n_embd*ggml_element_size(new_conv)
+        )
+    )
+);

auto * conv_kernel = model.layers[il].shortconv.conv;
-GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-// construct ssm_conv op
-ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
cb(conv_out, "model.layers.{}.conv.conv", il);

auto * y = ggml_mul(ctx0, c, conv_out);

y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
cb(y, "model.layers.{}.conv.out_proj", il);
+// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);

return y;
}
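
A note on the b/c/x views in the diff: the in_proj output packs three equal chunks along dimension 0, and each ggml_view_* call selects one chunk purely by byte offset, keeping the parent tensor's strides. Below is a minimal standalone sketch of the same pattern; the shapes and names (d, n_tokens, bcx) are illustrative assumptions, not the PR's code.

    #include "ggml.h"
    #include <cassert>

    int main() {
        // small scratch context; size is arbitrary for the sketch
        ggml_init_params params = { 16*1024*1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        const int64_t d = 4, n_tokens = 5;
        // stand-in for the packed in_proj output: {3*d, n_tokens}
        ggml_tensor * bcx = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3*d, n_tokens);

        // each chunk is a {d, n_tokens} window into bcx, selected by byte offset;
        // the row stride nb[1] stays that of the parent tensor
        ggml_tensor * b = ggml_view_2d(ctx, bcx, d, n_tokens, bcx->nb[1], 0*d*ggml_element_size(bcx));
        ggml_tensor * c = ggml_view_2d(ctx, bcx, d, n_tokens, bcx->nb[1], 1*d*ggml_element_size(bcx));
        ggml_tensor * x = ggml_view_2d(ctx, bcx, d, n_tokens, bcx->nb[1], 2*d*ggml_element_size(bcx));

        assert(b->ne[0] == d && c->ne[1] == n_tokens && x->ne[0] == d);
        ggml_free(ctx);
        return 0;
    }

The diff's move from ggml_view_2d to ggml_view_3d is the same idea with one extra dimension, so that each of b, c, and x carries the {chunk, n_seq_tokens, n_seqs} layout needed for per-sequence processing.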
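On the state write with kv_head: each sequence slot in the per-layer recurrent buffer holds d_conv * n_embd values, so slot kv_head begins kv_head * d_conv * n_embd elements into the buffer. This is what lets multiple sequences keep separate conv states instead of all writing to offset 0. A hedged sketch of the same offset arithmetic, wrapped in a hypothetical helper (the function name and layout assumption are mine, not the PR's):

    #include "ggml.h"

    // hypothetical helper mirroring the write in the diff above; conv_state is
    // assumed to be laid out as one {d_conv * n_embd} block per sequence slot
    static void write_conv_state(ggml_context * ctx, ggml_cgraph * gf,
                                 ggml_tensor * conv_state, ggml_tensor * new_conv,
                                 uint32_t kv_head, int64_t d_conv, int64_t n_embd) {
        // byte offset of slot kv_head; e.g. with d_conv = 2, n_embd = 2048, F32:
        // slot 3 starts at 3 * 2 * 2048 * 4 = 49152 bytes
        const size_t offset = kv_head * d_conv * n_embd * ggml_element_size(new_conv);
        ggml_tensor * dst = ggml_view_1d(ctx, conv_state, ggml_nelements(new_conv), offset);
        ggml_build_forward_expand(gf, ggml_cpy(ctx, new_conv, dst));
    }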