@@ -35,27 +35,27 @@ Session::Session(Instance& instance, InitParams params)
     sampler.perfReset();
 
     const auto ctxLen = llama_n_ctx(lctx);
-    maxTokens = ctxLen - 4; // (#16)
+    m_state.maxTokens = ctxLen - 4; // (#16)
 }
 
 void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
+    if (m_state.m_phase != State::Phase::Initial) {
+        throw_ex{} << "Session already started";
+    }
+
     Token initialToken; // used to reset the initial prompt to a single token
 
     auto lctx = m_instance.ctx();
     const auto ctxLen = llama_n_ctx(lctx);
     const auto tokenBos = llama_token_bos(m_instance.model().lmodel());
-    numKeep = std::min(uint32_t(initialPrompt.size()), maxTokens); // number of tokens to keep in the context in case we overflow
+    m_state.numKeep = std::min(uint32_t(initialPrompt.size()), m_state.maxTokens); // number of tokens to keep in the context in case we overflow
 
     if (initialPrompt.empty()) {
         initialToken = tokenBos;
         initialPrompt = {&initialToken, 1};
     }
 
-    if (initialPrompt.empty()) {
-        throw_ex{} << "Empty initial prompt";
-    }
-
-    if (initialPrompt.size() > maxTokens) {
+    if (initialPrompt.size() > m_state.maxTokens) {
         throw_ex{} << "Initial prompt too long. Got " << initialPrompt.size() << " tokens, max: " << ctxLen - 4;
     }
 
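For reference, the fields the diff accesses through `m_state` suggest an aggregate along these lines. This is a hedged sketch, not the declaration from the commit: the field names come from the diff, but the exact layout, the defaults, and the point where the phase flips to `Generating` are assumptions.

```cpp
// Hedged sketch of the assumed Session::State aggregate (not part of this commit).
// Field names are taken from the diff; types and defaults are guesses.
struct State {
    enum class Phase {
        Initial,    // before setInitialPrompt()/setState()
        Generating, // after the initial prompt has been decoded (assumed transition point)
    };

    Phase m_phase = Phase::Initial;
    Token m_currToken = Token_Invalid; // last sampled token, decoded lazily on the next getToken()

    uint32_t maxTokens = 0; // ctxLen - 4
    uint32_t numKeep = 0;   // tokens kept when the context overflows
    uint32_t numPast = 0;   // tokens already decoded into the KV cache
    uint32_t gaIndex = 0;   // group-attention (Self-Extend) index
};
```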
@@ -83,6 +83,10 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
 }
 
 void Session::pushPrompt(std::span<const Token> prompt) {
+    if (m_state.m_phase != State::Phase::Generating) {
+        throw_ex{} << "Session hasn't started yet";
+    }
+
     if (!prompt.empty()) {
         auto& sampler = m_instance.sampler();
         auto& model = m_instance.model();
@@ -101,23 +105,33 @@ void Session::pushPrompt(std::span<const Token> prompt) {
 }
 
 Token Session::getToken() {
+    if (m_state.m_phase != State::Phase::Generating) {
+        throw_ex{} << "Session hasn't started yet";
+    }
+
+    if (m_state.m_currToken != Token_Invalid) {
+        // first yield, then decode, thus we don't decode if the session is aborted
+        doDecode({&m_state.m_currToken, 1}, Source::Generated);
+    }
+
     auto& sampler = m_instance.sampler();
     auto& vocab = m_instance.model().vocab();
 
-    auto token = sampler.sample(m_instance.ctx());
+    m_state.m_currToken = sampler.sample(m_instance.ctx());
 
-    if (vocab.isEog(token)) {
-        return Token_Invalid;
+    if (vocab.isEog(m_state.m_currToken)) {
         // don't decode eog tokens in case the interaction is continued
+        m_state.m_currToken = Token_Invalid;
     }
 
-    // old-comment
-    // first yield, then decode, thus we don't decode if the session is aborted
-    doDecode({&token, 1}, Source::Generated);
-    return token;
+    return m_state.m_currToken;
 }
 
 std::vector<uint8_t> Session::getState() {
+    if (m_state.m_phase != State::Phase::Generating) {
+        throw_ex{} << "Session hasn't started yet";
+    }
+
     const auto size = llama_state_get_size(m_instance.ctx());
     std::vector<uint8_t> state(size);
     if (llama_state_get_data(m_instance.ctx(), state.data(), size) != size) {
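The reworked `getToken()` samples first and defers decoding the sampled token to the next call, so an aborted session never pays for a decode it doesn't need. A hedged usage sketch of the new flow; the generation loop and the `instance`, `params`, and `promptTokens` objects are assumptions, not part of the commit:

```cpp
// Hedged usage sketch: mirrors the API touched by this diff.
// `instance` (Instance&), `params` (InitParams) and `promptTokens` are assumed to exist.
Session session(instance, params);
session.setInitialPrompt(promptTokens); // throws if the session was already started

std::vector<Token> out;
while (true) {
    const Token t = session.getToken(); // decodes the previously yielded token, then samples a new one
    if (t == Token_Invalid) break;      // an end-of-generation token was sampled
    out.push_back(t);                   // yielded before it is decoded; aborting here skips the decode
}
```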
@@ -127,6 +141,10 @@ std::vector<uint8_t> Session::getState() {
 }
 
 bool Session::setState(std::span<uint8_t> state) {
+    if (m_state.m_phase != State::Phase::Initial) {
+        throw_ex{} << "Session already started";
+    }
+
     if (llama_state_set_data(m_instance.ctx(), state.data(), state.size()) != state.size()) {
         throw_ex{} << "Failed to set state";
     }
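With these checks, `getState()` and `setState()` now require opposite phases: a snapshot can only be taken from a session that has started generating, and can only be restored into a fresh one. A hedged sketch of that round trip; the second `Session` and its construction arguments are assumptions:

```cpp
// Hedged sketch: snapshotting a running session and restoring it into a new one.
std::vector<uint8_t> snapshot = session.getState(); // throws "Session hasn't started yet" before the initial prompt

Session restored(instance, params);                 // a fresh session, still in the Initial phase (assumed setup)
restored.setState(snapshot);                        // throws "Session already started" once generation has begun
```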
@@ -141,9 +159,9 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
     auto& sampler = m_instance.sampler();
 
     // Ensure the input doesn't exceed the context size by truncating embd if necessary.
-    if (tokens.size() > maxTokens) {
-        const auto skipped = tokens.size() - maxTokens;
-        tokens = tokens.first(maxTokens);
+    if (tokens.size() > m_state.maxTokens) {
+        const auto skipped = tokens.size() - m_state.maxTokens;
+        tokens = tokens.first(m_state.maxTokens);
         LLAMA_LOG(Warning, "Input too long. Skipping ", skipped, " tokens");
     }
 
@@ -153,49 +171,49 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
         // if we run out of context:
         // - take the n_keep first tokens from the original prompt (via numPast)
         // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-        const auto num = numPast + tokens.size();
+        const auto num = m_state.numPast + tokens.size();
         if (num >= ctxLen) {
             if (!m_params.infiniteContext) {
                 throw_ex{} << "context limit of " << ctxLen << " reached";
             }
 
-            const auto numLeft = numPast - numKeep;
+            const auto numLeft = m_state.numPast - m_state.numKeep;
             const int numDiscard = numLeft / 2; // somewhat arbitrary
 
-            LLAMA_LOG(Debug, "Context is full. Swapping: past = ", numPast, ", numLeft: ", numLeft,
-                ", ctxLen: ", ctxLen, ", numKeep: ", numKeep, ", numDiscard: ", numDiscard);
+            LLAMA_LOG(Debug, "Context is full. Swapping: past = ", m_state.numPast, ", numLeft: ", numLeft,
+                ", ctxLen: ", ctxLen, ", numKeep: ", m_state.numKeep, ", numDiscard: ", numDiscard);
 
-            llama_kv_cache_seq_rm(lctx, 0, numKeep, numKeep + numDiscard);
-            llama_kv_cache_seq_add(lctx, 0, numKeep + numDiscard, numPast, -numDiscard);
+            llama_kv_cache_seq_rm(lctx, 0, m_state.numKeep, m_state.numKeep + numDiscard);
+            llama_kv_cache_seq_add(lctx, 0, m_state.numKeep + numDiscard, m_state.numPast, -numDiscard);
 
-            numPast -= numDiscard;
+            m_state.numPast -= numDiscard;
             haveFullContextMitigation = true;
         }
     }
     else {
         const uint32_t gaWidth = m_params.gaWidth;
 
-        while (numPast >= gaIndex + gaWidth) {
+        while (m_state.numPast >= m_state.gaIndex + gaWidth) {
             // context extension via Self-Extend
-            const int ib = (gaFactor * gaIndex) / gaWidth;
+            const int ib = (gaFactor * m_state.gaIndex) / gaWidth;
             const int bd = (gaWidth / gaFactor) * (gaFactor - 1);
             const int dd = (gaWidth / gaFactor) - ib * bd - gaWidth;
 
             LLAMA_LOG(Debug, "Group attention shift: ib = ", ib, ", bd = ", bd, ", dd = ", dd);
 
-            llama_kv_cache_seq_add(lctx, 0, gaIndex, numPast, ib * bd);
-            llama_kv_cache_seq_div(lctx, 0, gaIndex + ib * bd, gaIndex + ib * bd + gaWidth, gaFactor);
-            llama_kv_cache_seq_add(lctx, 0, gaIndex + ib * bd + gaWidth, numPast + ib * bd, dd);
+            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex, m_state.numPast, ib * bd);
+            llama_kv_cache_seq_div(lctx, 0, m_state.gaIndex + ib * bd, m_state.gaIndex + ib * bd + gaWidth, gaFactor);
+            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex + ib * bd + gaWidth, m_state.numPast + ib * bd, dd);
 
-            numPast -= bd;
+            m_state.numPast -= bd;
 
-            gaIndex += gaWidth / gaFactor;
+            m_state.gaIndex += gaWidth / gaFactor;
             haveFullContextMitigation = true;
         }
     }
 
     if (haveFullContextMitigation) {
-        LLAMA_LOG(Info, "Context full mitigation performed: past = ", numPast, ", tokens = ", tokens.size());
+        LLAMA_LOG(Info, "Context full mitigation performed: past = ", m_state.numPast, ", tokens = ", tokens.size());
     }
 
     // add to sampler
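The halving in the first branch above (plain context shifting, presumably taken when group attention is off) is easier to follow with numbers. A hedged worked example with made-up values, not taken from the commit:

```cpp
// Hedged worked example of the context-shift path (illustrative numbers only).
// Assume ctxLen = 4096, m_state.numKeep = 64, m_state.numPast = 4095, and one new token arrives:
//   num        = numPast + tokens.size() = 4096  -> num >= ctxLen, so mitigation kicks in
//   numLeft    = numPast - numKeep       = 4031
//   numDiscard = numLeft / 2             = 2015
// llama_kv_cache_seq_rm drops KV cells [64, 2079),
// llama_kv_cache_seq_add shifts cells [2079, 4095) back by 2015,
// and numPast drops to 4095 - 2015 = 2080, leaving room for the new token.
```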
@@ -215,7 +233,7 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
         if (llama_decode(lctx, batch) != 0) {
             throw_ex{} << "Failed to decode tokens";
         }
-        numPast += uint32_t(batchTokens.size());
+        m_state.numPast += uint32_t(batchTokens.size());
     }
 
 }