@@ -17,12 +17,11 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_context_i;
+struct llama_memory_state_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
-class llama_memory_recurrent_context;
-class llama_memory_hybrid_context;
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_kv_cache_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -38,7 +37,6 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
-    LLM_FFN_REGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -96,14 +94,14 @@ class llm_graph_input_embd : public llm_graph_input_i {
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const uint32_t n_pos_per_embd = 1;
+    const int64_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
@@ -137,16 +135,15 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,16 +188,28 @@ class llm_graph_input_cls : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_rs : public llm_graph_input_i {
+class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
-    virtual ~llm_graph_input_rs() = default;
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_context * mctx;
+    const llama_kv_cache_recurrent_state * kv_state;
+};
+
+class llm_graph_input_s_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    virtual ~llm_graph_input_s_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_mask; // F32 [1, n_kv]
+
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -240,10 +249,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -257,18 +266,18 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -285,7 +294,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -303,44 +312,6 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
     const llama_cross * cross = nullptr;
 };
 
-class llm_graph_input_mem_hybrid : public llm_graph_input_i {
-public:
-    llm_graph_input_mem_hybrid(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_memory_hybrid_context * mctx) :
-        hparams(hparams),
-        cparams(cparams),
-        mctx(mctx) {
-    }
-    virtual ~llm_graph_input_mem_hybrid() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_copy; // I32 [kv_size]
-
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
-
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
-
-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
-
-    const llama_memory_hybrid_context * mctx;
-};
-
-// TODO: remove this when ggml_scale_add is implemented
-class llm_graph_input_one : public llm_graph_input_i {
-public:
-    llm_graph_input_one() {}
-    virtual ~llm_graph_input_one() = default;
-
-    void set_input(const llama_ubatch *) override;
-
-    ggml_tensor * one = nullptr; // F32
-};
-
 //
 // llm_graph_result
 //
@@ -414,12 +385,12 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
-    uint32_t n_outputs;
+    int32_t n_outputs;
 
     const llm_graph_cb & cb;
 };
@@ -453,8 +424,8 @@ struct llm_graph_context {
     const float norm_eps;
     const float norm_rms_eps;
 
-    const int64_t n_tokens;
-    const int64_t n_outputs;
+    const int32_t n_tokens;
+    const int32_t n_outputs;
     const int32_t n_ctx_orig; // yarn
 
     const enum llama_pooling_type pooling_type;
@@ -466,17 +437,18 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
-    virtual ~llm_graph_context() = default;
+
+    int64_t n_pos_per_embd() const;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -548,14 +520,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
+    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
-    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
-
     //
     // attention
     //
@@ -602,15 +574,14 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
-    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
@@ -631,62 +602,23 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    ggml_tensor * build_attn(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
     //
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
-    // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
-    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
-    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
-    //       `llama_memory_recurrent`
-    ggml_tensor * build_rs(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            int32_t state_size,
-            int32_t n_seqs,
-            uint32_t n_kv,
-            uint32_t kv_head,
-            uint32_t kv_size,
-            int32_t rs_zero,
-            bool avoid_copies = false) const;
-
-    llm_graph_input_rs * build_rs_inp() const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool avoid_copies = false) const;
-
-    ggml_tensor * build_rs(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool avoid_copies = false) const;
+    ggml_tensor * build_copy_mask_state(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            int32_t n_state,
+            int32_t n_seqs) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-            llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
+            ggml_cgraph * gf,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
             int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(