@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
16571657 }
16581658}
16591659
1660- size_t llama_context::state_seq_get_size (llama_seq_id seq_id) {
1660+ size_t llama_context::state_seq_get_size (llama_seq_id seq_id, llama_state_seq_flags flags ) {
16611661 llama_io_write_dummy io;
16621662 try {
1663- return state_seq_write_data (io, seq_id);
1663+ return state_seq_write_data (io, seq_id, flags );
16641664 } catch (const std::exception & err) {
16651665 LLAMA_LOG_ERROR (" %s: error getting state size: %s\n " , __func__, err.what ());
16661666 return 0 ;
16671667 }
16681668}
16691669
1670- size_t llama_context::state_seq_get_data (llama_seq_id seq_id, uint8_t * dst, size_t size) {
1670+ size_t llama_context::state_seq_get_data (llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags ) {
16711671 llama_io_write_buffer io (dst, size);
16721672 try {
1673- return state_seq_write_data (io, seq_id);
1673+ return state_seq_write_data (io, seq_id, flags );
16741674 } catch (const std::exception & err) {
16751675 LLAMA_LOG_ERROR (" %s: error saving state: %s\n " , __func__, err.what ());
16761676 return 0 ;
16771677 }
16781678}
16791679
1680- size_t llama_context::state_seq_set_data (llama_seq_id seq_id, const uint8_t * src, size_t size) {
1680+ size_t llama_context::state_seq_set_data (llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags ) {
16811681 llama_io_read_buffer io (src, size);
16821682 try {
1683- return state_seq_read_data (io, seq_id);
1683+ return state_seq_read_data (io, seq_id, flags );
16841684 } catch (const std::exception & err) {
16851685 LLAMA_LOG_ERROR (" %s: error loading state: %s\n " , __func__, err.what ());
16861686 return 0 ;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
17781778 {
17791779 const size_t state_size = file.size () - file.tell ();
17801780 llama_io_read_file io (&file);
1781- const size_t nread = state_seq_read_data (io, seq_id);
1781+ const size_t nread = state_seq_read_data (io, seq_id, 0 );
17821782 if (!nread) {
17831783 LLAMA_LOG_ERROR (" %s: failed to restore sequence state\n " , __func__);
17841784 return 0 ;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
18021802
18031803 // save the context state using stream saving
18041804 llama_io_write_file io (&file);
1805- state_seq_write_data (io, seq_id);
1805+ state_seq_write_data (io, seq_id, 0 );
18061806
18071807 const size_t res = file.tell ();
18081808 GGML_ASSERT (res == sizeof (uint32_t ) * 3 + sizeof (llama_token) * n_token_count + io.n_bytes ());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
19711971 return io.n_bytes ();
19721972}
19731973
1974- size_t llama_context::state_seq_write_data (llama_io_write_i & io, llama_seq_id seq_id) {
1974+ size_t llama_context::state_seq_write_data (llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags ) {
19751975 GGML_UNUSED (seq_id);
19761976
19771977 if (memory) {
1978- memory->state_write (io, seq_id);
1978+ memory->state_write (io, seq_id, flags );
19791979 }
19801980
19811981 return io.n_bytes ();
19821982}
19831983
1984- size_t llama_context::state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) {
1984+ size_t llama_context::state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags ) {
19851985 GGML_UNUSED (seq_id);
19861986
19871987 if (memory) {
1988- memory->state_read (io, seq_id);
1988+ memory->state_read (io, seq_id, flags );
19891989 }
19901990
19911991 return io.n_bytes ();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
28012801}
28022802
28032803size_t llama_state_seq_get_size (llama_context * ctx, llama_seq_id seq_id) {
2804- return ctx-> state_seq_get_size ( seq_id);
2804+ return llama_state_seq_get_size_ext (ctx, seq_id, 0 );
28052805}
28062806
28072807size_t llama_state_seq_get_data (llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
2808+ return llama_state_seq_get_data_ext (ctx, dst, size, seq_id, 0 );
2809+ }
2810+
2811+ size_t llama_state_seq_set_data (llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2812+ return llama_state_seq_set_data_ext (ctx, src, size, seq_id, 0 );
2813+ }
2814+
2815+ size_t llama_state_seq_get_size_ext (llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
2816+ return ctx->state_seq_get_size (seq_id, flags);
2817+ }
2818+
2819+ size_t llama_state_seq_get_data_ext (llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
28082820 ctx->synchronize ();
28092821
2810- return ctx->state_seq_get_data (seq_id, dst, size);
2822+ return ctx->state_seq_get_data (seq_id, dst, size, flags );
28112823}
28122824
2813- size_t llama_state_seq_set_data (llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2825+ size_t llama_state_seq_set_data_ext (llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags ) {
28142826 ctx->synchronize ();
28152827
2816- return ctx->state_seq_set_data (seq_id, src, size);
2828+ return ctx->state_seq_set_data (seq_id, src, size, flags );
28172829}
28182830
28192831size_t llama_state_seq_save_file (llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
0 commit comments