@@ -16554,46 +16554,68 @@ struct llm_build_lfm2 : public llm_graph_context {
                                         ggml_tensor        * cur,
                                         llm_graph_input_rs * inp_recr,
                                         int                  il) {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const auto *   mctx_cur     = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const uint32_t kv_head      = mctx_cur->get_head();
+        const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t  n_seqs       = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
         cb(bcx, "model.layers.{}.conv.in_proj", il);
 
         constexpr auto n_chunks = 3;
         GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
         auto const chunk_size = bcx->ne[0] / n_chunks;
-        auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
-        auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
-        auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
+        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0 * chunk_size * ggml_element_size(bcx));
+        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1 * chunk_size * ggml_element_size(bcx));
+        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2 * chunk_size * ggml_element_size(bcx));
 
         auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
 
-        // read conv state directly, with build_rs generation is slower
-        ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        const int64_t n_seqs = ubatch.n_seqs;
-        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs    = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
 
         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);
 
-        auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        // the last d_conv columns are the new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
         GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
 
-        // write conv state
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+        // write the new conv state
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    new_conv,
+                    ggml_view_1d(
+                        ctx0,
+                        conv_state,
+                        ggml_nelements(new_conv),
+                        kv_head * d_conv * n_embd * ggml_element_size(new_conv)
+                        )
+                    )
+                );
 
         auto * conv_kernel = model.layers[il].shortconv.conv;
-        GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-        // construct ssm_conv op
-        ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
         cb(conv_out, "model.layers.{}.conv.conv", il);
 
         auto * y = ggml_mul(ctx0, c, conv_out);
-
         y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
         cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
 
         return y;
     }
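
Note on the b/c/x split above: the three `ggml_view_3d` calls keep `bcx`'s strides (`nb[1]`, `nb[2]`) and differ only in the byte offset, so view `k` selects elements `[k*chunk_size, (k+1)*chunk_size)` of dim 0 for every (token, sequence) pair. A minimal sketch of the same offset arithmetic in plain C++ (no ggml; the sizes are made up, and the tensor is flattened to 2D with `n_tokens = n_seq_tokens * n_seqs` for brevity):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int chunk_size = 4;           // made-up stand-in for bcx->ne[0] / 3
    const int n_tokens   = 2;           // made-up n_seq_tokens * n_seqs
    const int ne0        = 3 * chunk_size;

    // dim 0 is contiguous: one column of ne0 elements per token
    std::vector<float> bcx(ne0 * n_tokens);
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < ne0; ++i) {
            bcx[t * ne0 + i] = 100.0f * t + i;
        }
    }

    // a view keeps the column stride and shifts the start by k*chunk_size
    // elements -- the element form of "k * chunk_size * ggml_element_size(bcx)"
    auto view = [&](int k, int t, int i) {
        return bcx[t * ne0 + k * chunk_size + i];
    };

    for (int k = 0; k < 3; ++k) {       // k = 0: b, 1: c, 2: x
        std::printf("chunk %d, token 1: %g %g %g %g\n",
                    k, view(k, 1, 0), view(k, 1, 1), view(k, 1, 2), view(k, 1, 3));
    }
    return 0;
}
```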
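Note on the state write-back: the `ggml_cpy` destination is no longer the whole `conv_state` tensor but a 1-D view at byte offset `kv_head * d_conv * n_embd * ggml_element_size(new_conv)`. Assuming each cache cell stores one `d_conv x n_embd` conv state contiguously, this lands the batch's `n_seqs` new states at cell `kv_head` instead of at the start of the buffer, which is what lets parallel sequences keep separate states. A small sketch of that arithmetic (plain C++, all values hypothetical):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t d_conv  = 2;   // hypothetical n_shortconv_l_cache - 1
    const std::size_t n_embd  = 8;   // hypothetical model width
    const std::size_t n_seqs  = 2;   // sequences in the ubatch
    const std::size_t kv_head = 3;   // first cache cell assigned to this ubatch

    // one cell = d_conv * n_embd floats, cells laid out back to back
    const std::size_t byte_off = kv_head * d_conv * n_embd * sizeof(float);
    // elements copied, i.e. ggml_nelements(new_conv) = n_seqs * d_conv * n_embd
    const std::size_t n_elems  = n_seqs * d_conv * n_embd;

    std::printf("copy %zu floats at byte offset %zu\n", n_elems, byte_off);
    return 0;
}
```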