@@ -79,13 +79,16 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
7979 }
8080
8181 doDecode (initialPrompt, Source::InitialPrompt);
82+ m_state.m_phase = State::Phase::Generating;
8283}
8384
8485void Session::pushPrompt (std::span<const Token> prompt) {
8586 if (m_state.m_phase != State::Phase::Generating) {
8687 throw_ex{} << " Session hasn't started yet" ;
8788 }
8889
90+ flushPendingState ();
91+
8992 if (!prompt.empty ()) {
9093 auto & sampler = m_instance.sampler ();
9194 auto & model = m_instance.model ();
@@ -108,10 +111,7 @@ Token Session::getToken() {
108111 throw_ex{} << " Session hasn't started yet" ;
109112 }
110113
111- if (m_state.m_currToken != Token_Invalid) {
112- // first yield, then decode, thus we don't decode if the session is aborted
113- doDecode ({&m_state.m_currToken , 1 }, Source::Generated);
114- }
114+ flushPendingState ();
115115
116116 auto & sampler = m_instance.sampler ();
117117 auto & vocab = m_instance.model ().vocab ();
@@ -131,6 +131,8 @@ std::vector<uint8_t> Session::getState() {
131131 throw_ex{} << " Session hasn't started yet" ;
132132 }
133133
134+ flushPendingState ();
135+
134136 const auto size = llama_state_get_size (m_ctx);
135137 std::vector<uint8_t > state (size);
136138 if (llama_state_get_data (m_ctx, state.data (), size) != size) {
@@ -147,6 +149,8 @@ bool Session::setState(std::span<uint8_t> state) {
147149 if (llama_state_set_data (m_ctx, state.data (), state.size ()) != state.size ()) {
148150 throw_ex{} << " Failed to set state" ;
149151 }
152+
153+ m_state.m_phase = State::Phase::Generating;
150154 return true ;
151155}
152156
@@ -235,4 +239,11 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
235239
236240}
237241
242+ void Session::flushPendingState () {
243+ if (m_state.m_currToken != Token_Invalid) {
244+ // first yield, then decode, thus we don't decode if the session is aborted
245+ doDecode ({&m_state.m_currToken , 1 }, Source::Generated);
246+ m_state.m_currToken = Token_Invalid;
247+ }
248+ }
238249} // namespace ac::llama
0 commit comments