Commit 4a29561

model : support n_seq > 1 for lfm2

1 parent c81f419

File tree: 1 file changed, +39 -17 lines

src/llama-model.cpp

Lines changed: 39 additions & 17 deletions
@@ -16430,46 +16430,68 @@ struct llm_build_lfm2 : public llm_graph_context {
             ggml_tensor * cur,
             llm_graph_input_rs * inp_recr,
             int il) {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const uint32_t kv_head = mctx_cur->get_head();
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

         auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
         cb(bcx, "model.layers.{}.conv.in_proj", il);

         constexpr auto n_chunks = 3;
         GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
         auto const chunk_size = bcx->ne[0] / n_chunks;
-        auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
-        auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
-        auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
+        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
+        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
+        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));

         auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));

-        // read conv state directly, with build_rs generation is slower
-        ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        const int64_t n_seqs = ubatch.n_seqs;
-        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);

         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);

-        auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        // last d_conv columns is a new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
         GGML_ASSERT(ggml_are_same_shape(conv, new_conv));

-        // write conv state
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+        // write new conv conv state
+        ggml_build_forward_expand(
+            gf,
+            ggml_cpy(
+                ctx0,
+                new_conv,
+                ggml_view_1d(
+                    ctx0,
+                    conv_state,
+                    ggml_nelements(new_conv),
+                    kv_head*d_conv*n_embd*ggml_element_size(new_conv)
+                )
+            )
+        );

         auto * conv_kernel = model.layers[il].shortconv.conv;
-        GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-        // construct ssm_conv op
-        ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
         cb(conv_out, "model.layers.{}.conv.conv", il);

         auto * y = ggml_mul(ctx0, c, conv_out);
-
         y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
         cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);

         return y;
     }
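For reference, the change amounts to carrying an explicit {n_embd, n_seq_tokens, n_seqs} layout through the short-conv block instead of the flat {n_embd, n_tokens} layout used before. Below is a minimal sketch of that shape bookkeeping, in plain C++ rather than ggml; all sizes are hypothetical example values and d_inner is assumed equal to n_embd, as in the LFM2 short-conv block.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Shape bookkeeping of the lfm2 short-conv block once a ubatch may carry
// several sequences. Dimensions follow ggml order (ne[0] is the contiguous
// dim). All sizes are hypothetical example values.
int main() {
    const int64_t n_embd       = 2048; // model width (assumption)
    const int64_t n_seq_tokens = 4;    // tokens per sequence in the ubatch
    const int64_t n_seqs       = 3;    // sequences in the ubatch
    const int64_t n_tokens     = n_seq_tokens * n_seqs;
    const int64_t d_conv       = 2;    // n_shortconv_l_cache - 1

    // cur: {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
    const int64_t cur[3]  = { n_embd, n_seq_tokens, n_seqs };
    // in_proj output, split into b/c/x thirds along dim 0 via ggml_view_3d
    const int64_t bcx[3]  = { 3 * n_embd, n_seq_tokens, n_seqs };
    const int64_t b[3]    = { bcx[0] / 3, bcx[1], bcx[2] };
    // bx = transpose(b * x): {n_seq_tokens, n_embd, n_seqs}
    const int64_t bx[3]   = { b[1], b[0], b[2] };
    // conv state from build_rs: one {d_conv, n_embd} slab per sequence
    const int64_t conv[3] = { d_conv, n_embd, n_seqs };
    // concat along dim 0 prepends the cached columns to the new tokens
    const int64_t cat[3]  = { conv[0] + bx[0], n_embd, n_seqs };
    // the last d_conv columns of the concat become the new conv state
    const int64_t new_conv[3] = { d_conv, n_embd, n_seqs };

    assert(cur[1] * cur[2] == n_tokens);
    assert(b[0] == n_embd);
    assert(cat[0] > conv[0]); // mirrors GGML_ASSERT(bx->ne[0] > conv->ne[0])
    assert(new_conv[0] == conv[0] && new_conv[1] == conv[1] && new_conv[2] == conv[2]);

    // ssm_conv and out_proj keep {n_embd, n_seq_tokens, n_seqs}; the block
    // finally reshapes back to {n_embd, n_tokens} before returning.
    std::printf("ubatch: %lld tokens = %lld seqs x %lld tokens/seq\n",
                (long long) n_tokens, (long long) n_seqs, (long long) n_seq_tokens);
    return 0;
}

The asserts mirror the GGML_ASSERT checks added in the commit; the real graph builds the same shapes with ggml_reshape_3d, ggml_view_3d, ggml_concat, and ggml_reshape_2d.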

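The other substantive change is where the updated conv state is written back: rather than copying over the whole conv_state tensor, the commit copies new_conv into a 1-D view that starts at the state cell selected by kv_head. A small sketch of that offset arithmetic follows, assuming the recurrent-state buffer holds one {d_conv, n_embd} slab per cell; the sizes are hypothetical and elt_size merely stands in for ggml_element_size().

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Sketch of the destination offset used when writing the updated conv state
// back into the per-layer recurrent-state buffer. The buffer holds one
// {d_conv, n_embd} slab per state cell; a ubatch writes n_seqs consecutive
// slabs starting at cell kv_head. All sizes are hypothetical examples.
int main() {
    const int64_t  d_conv   = 2;     // n_shortconv_l_cache - 1
    const int64_t  n_embd   = 2048;  // model width (assumption)
    const int64_t  n_seqs   = 3;     // sequences in the current ubatch
    const uint32_t kv_head  = 5;     // first state cell used by the ubatch
    const size_t   elt_size = 4;     // e.g. sizeof(float); stands in for ggml_element_size()

    // elements per state cell, and for the whole ubatch (= ggml_nelements(new_conv))
    const int64_t cell_elems  = d_conv * n_embd;
    const int64_t write_elems = cell_elems * n_seqs;

    // byte offset of the 1-D view into conv_state, as in the commit:
    // kv_head * d_conv * n_embd * ggml_element_size(new_conv)
    const size_t offset_bytes = (size_t) kv_head * cell_elems * elt_size;

    std::printf("copy %lld elements at byte offset %zu (cells %u..%lld)\n",
                (long long) write_elems, offset_bytes,
                kv_head, (long long) (kv_head + n_seqs - 1));
    return 0;
}

With n_seqs slabs written starting at kv_head, each sequence in the ubatch keeps its own conv state, which is what lets the block run with n_seq > 1.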