Skip to content

Commit 2645a7d

Browse files
committed
context : add save/load for recurrent context
ggml-ci
1 parent 08011c2 commit 2645a7d

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

src/llama-context.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3657,6 +3657,40 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
36573657
return inp_kq_mask_cross;
36583658
}
36593659

3660+
// state save/load
3661+
3662+
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
3663+
llama_context::state_get_data(io);
3664+
3665+
kv_self.state_write(io);
3666+
3667+
return io.n_bytes();
3668+
}
3669+
3670+
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
3671+
llama_context::state_set_data(io);
3672+
3673+
kv_self.state_read(io);
3674+
3675+
return io.n_bytes();
3676+
}
3677+
3678+
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
3679+
llama_context::state_seq_get_data(io, seq_id);
3680+
3681+
kv_self.state_write(io, seq_id);
3682+
3683+
return io.n_bytes();
3684+
}
3685+
3686+
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
3687+
llama_context::state_seq_set_data(io, seq_id);
3688+
3689+
kv_self.state_read(io, seq_id);
3690+
3691+
return io.n_bytes();
3692+
}
3693+
36603694
//
36613695
// llama_context_recurrent
36623696
//
@@ -4527,31 +4561,31 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
45274561

45284562
// state save/load
45294563

4530-
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
4564+
size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) {
45314565
llama_context::state_get_data(io);
45324566

45334567
kv_self.state_write(io);
45344568

45354569
return io.n_bytes();
45364570
}
45374571

4538-
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
4572+
size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) {
45394573
llama_context::state_set_data(io);
45404574

45414575
kv_self.state_read(io);
45424576

45434577
return io.n_bytes();
45444578
}
45454579

4546-
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
4580+
size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
45474581
llama_context::state_seq_get_data(io, seq_id);
45484582

45494583
kv_self.state_write(io, seq_id);
45504584

45514585
return io.n_bytes();
45524586
}
45534587

4554-
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
4588+
size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
45554589
llama_context::state_seq_set_data(io, seq_id);
45564590

45574591
kv_self.state_read(io, seq_id);

src/llama-context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,12 @@ class llama_context_recurrent : public llama_context {
525525
bool worst_case) override;
526526

527527
protected:
528+
virtual size_t state_get_data(llama_io_write_i & io) override;
529+
virtual size_t state_set_data(llama_io_read_i & io) override;
530+
531+
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
532+
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
533+
528534
virtual void input_set(const llama_ubatch & ubatch) override;
529535

530536
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?

0 commit comments

Comments
 (0)