@@ -16564,46 +16564,68 @@ struct llm_build_lfm2 : public llm_graph_context {
                                         ggml_tensor * cur,
                                         llm_graph_input_rs * inp_recr,
                                         int il) {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const uint32_t kv_head = mctx_cur->get_head();
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
         cb(bcx, "model.layers.{}.conv.in_proj", il);
 
         constexpr auto n_chunks = 3;
         GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
         auto const chunk_size = bcx->ne[0] / n_chunks;
-        auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0*chunk_size*ggml_element_size(bcx));
-        auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1*chunk_size*ggml_element_size(bcx));
-        auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2*chunk_size*ggml_element_size(bcx));
+        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
+        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
+        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));
 
         auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
 
-        // read conv state directly, with build_rs generation is slower
-        ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        const int64_t n_seqs = ubatch.n_seqs;
-        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
 
         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);
 
-        auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
+        // the last d_conv columns are the new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
         GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
 
-        // write conv state
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+        // write the new conv state
+        ggml_build_forward_expand(
+            gf,
+            ggml_cpy(
+                ctx0,
+                new_conv,
+                ggml_view_1d(
+                    ctx0,
+                    conv_state,
+                    ggml_nelements(new_conv),
+                    kv_head*d_conv*n_embd*ggml_element_size(new_conv)
+                )
+            )
+        );
 
         auto * conv_kernel = model.layers[il].shortconv.conv;
-        GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-        // construct ssm_conv op
-        ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
         cb(conv_out, "model.layers.{}.conv.conv", il);
 
         auto * y = ggml_mul(ctx0, c, conv_out);
-
         y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
         cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
 
         return y;
     }
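
The core of this change is how the rolling convolution state is maintained: the cached d_conv input columns are concatenated in front of the new tokens, ggml_ssm_conv runs over the widened tensor, and the last d_conv columns are written back into the recurrent cache at the kv_head slot. Below is a minimal plain-C++ sketch of that sliding-window contract, written outside of ggml; the helper name shortconv_step and the single-channel, single-sequence setup are illustrative assumptions, not llama.cpp API. The invariant the cache update has to preserve is that chunked and full-batch evaluation agree:

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// One decode step: prepend the cached d_conv inputs to the new tokens, apply a
// causal depthwise conv of width d_conv + 1, and keep the last d_conv inputs as
// the next state (mirrors the ggml_concat + ggml_view + ggml_cpy sequence above).
static std::vector<float> shortconv_step(std::vector<float> & state,     // size d_conv
                                         const std::vector<float> & x,   // new tokens
                                         const std::vector<float> & w) { // kernel, size d_conv + 1
    const size_t d_conv = state.size();
    assert(w.size() == d_conv + 1);

    // bx = concat(state, x) along the time axis
    std::vector<float> bx(state);
    bx.insert(bx.end(), x.begin(), x.end());

    // one output per new token: out[t] = sum_k w[k] * bx[t + k]
    std::vector<float> out(x.size());
    for (size_t t = 0; t < x.size(); ++t) {
        float acc = 0.0f;
        for (size_t k = 0; k <= d_conv; ++k) {
            acc += w[k]*bx[t + k];
        }
        out[t] = acc;
    }

    // new state = last d_conv columns of bx
    state.assign(bx.end() - d_conv, bx.end());
    return out;
}

int main() {
    const std::vector<float> w = {0.25f, 0.5f, 1.0f}; // d_conv = 2, kernel width 3

    // full batch
    std::vector<float> s1(2, 0.0f);
    std::vector<float> y = shortconv_step(s1, {1, 2, 3, 4}, w);

    // same tokens in two chunks, state carried across the boundary
    std::vector<float> s2(2, 0.0f);
    std::vector<float> ya = shortconv_step(s2, {1, 2}, w);
    std::vector<float> yb = shortconv_step(s2, {3, 4}, w);
    ya.insert(ya.end(), yb.begin(), yb.end());

    for (size_t i = 0; i < y.size(); ++i) {
        assert(y[i] == ya[i]);
    }
    printf("chunked == full-batch: ok\n");
    return 0;
}

This also makes the sizing visible: with d_conv = n_shortconv_l_cache - 1, the widened input has d_conv + n_seq_tokens columns, which lines up with a conv kernel of width n_shortconv_l_cache and with the GGML_ASSERT(hparams.n_shortconv_l_cache > 1) above.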