@@ -14028,7 +14028,11 @@ struct llm_build_granite_hybrid : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp = build_inp_mem_hybrid();
+        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
+
+        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -14049,11 +14053,11 @@ struct llm_build_granite_hybrid : public llm_graph_context {
 
             if (hparams.is_recurrent(il)) {
                 // ssm layer //
-                cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il);
+                cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il);
             } else {
                 // attention layer //
                 cur = build_granite_attention_layer(
-                    gf, cur, inp_pos, inp, model,
+                    gf, cur, inp_pos, inp_attn, model,
                     n_embd_head, use_rope, il);
             }
 
@@ -14092,12 +14096,12 @@ struct llm_build_granite_hybrid : public llm_graph_context {
     }
 
     ggml_tensor * build_mamba2_layer(
-        llm_graph_input_mem_hybrid * inp,
-                       ggml_cgraph * gf,
-                       ggml_tensor * cur,
-                 const llama_model & model,
-                const llama_ubatch & ubatch,
-                               int   il) const {
+        llm_graph_input_rs * inp,
+               ggml_cgraph * gf,
+               ggml_tensor * cur,
+         const llama_model & model,
+        const llama_ubatch & ubatch,
+                       int   il) const {
         const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
         const auto kv_head = mctx_cur->get_head();
@@ -14221,14 +14225,14 @@ struct llm_build_granite_hybrid : public llm_graph_context {
     }
 
     ggml_tensor * build_granite_attention_layer(
-                       ggml_cgraph * gf,
-                       ggml_tensor * cur,
-                       ggml_tensor * inp_pos,
-        llm_graph_input_mem_hybrid * inp,
-                 const llama_model & model,
-                     const int64_t   n_embd_head,
-                        const bool   use_rope,
-                         const int   il) {
+                            ggml_cgraph * gf,
+                            ggml_tensor * cur,
+                            ggml_tensor * inp_pos,
+        llm_graph_input_attn_kv_unified * inp,
+                      const llama_model & model,
+                          const int64_t   n_embd_head,
+                             const bool   use_rope,
+                              const int   il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);