@@ -21,20 +21,20 @@ llama_batch makeInputBatch(std::span<const Token> tokens) {
 }
 }
 
-Session::Session(Instance& instance, InitParams params)
+Session::Session(Instance& instance, llama_context* ctx, InitParams params)
     : m_instance(instance)
+    , m_ctx(ctx)
     , m_params(std::move(params))
 {
-    auto lctx = m_instance.ctx();
     auto& sampler = m_instance.sampler();
 
-    llama_kv_cache_clear(lctx);
-    llama_synchronize(lctx);
-    llama_perf_context_reset(lctx);
+    llama_kv_cache_clear(m_ctx);
+    llama_synchronize(m_ctx);
+    llama_perf_context_reset(m_ctx);
     sampler.reset();
     sampler.perfReset();
 
-    const auto ctxLen = llama_n_ctx(lctx);
+    const auto ctxLen = llama_n_ctx(m_ctx);
     m_state.maxTokens = ctxLen - 4; // (#16)
 }
 
@@ -45,8 +45,7 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
 
     Token initialToken; // used to reset the initial prompt to a single token
 
-    auto lctx = m_instance.ctx();
-    const auto ctxLen = llama_n_ctx(lctx);
+    const auto ctxLen = llama_n_ctx(m_ctx);
     const auto tokenBos = llama_token_bos(m_instance.model().lmodel());
     m_state.numKeep = std::min(uint32_t(initialPrompt.size()), m_state.maxTokens); // number of tokens to keep in the context in case we overflow
 
@@ -70,7 +69,7 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
 
     if (m_instance.model().hasEncoder()) {
         auto batch = makeInputBatch(initialPrompt);
-        auto res = llama_encode(lctx, batch);
+        auto res = llama_encode(m_ctx, batch);
         if (res != 0) {
             throw_ex{} << "Failed to encode input";
         }
@@ -117,7 +116,7 @@ Token Session::getToken() {
     auto& sampler = m_instance.sampler();
     auto& vocab = m_instance.model().vocab();
 
-    m_state.m_currToken = sampler.sample(m_instance.ctx());
+    m_state.m_currToken = sampler.sample(m_ctx);
 
     if (vocab.isEog(m_state.m_currToken)) {
         // don't decode eog tokens in case the interaction is continued
@@ -132,9 +131,9 @@ std::vector<uint8_t> Session::getState() {
         throw_ex{} << "Session hasn't started yet";
     }
 
-    const auto size = llama_state_get_size(m_instance.ctx());
+    const auto size = llama_state_get_size(m_ctx);
     std::vector<uint8_t> state(size);
-    if (llama_state_get_data(m_instance.ctx(), state.data(), size) != size) {
+    if (llama_state_get_data(m_ctx, state.data(), size) != size) {
         throw_ex{} << "Failed to get state";
     }
     return state;
@@ -145,19 +144,13 @@ bool Session::setState(std::span<uint8_t> state) {
         throw_ex{} << "Session already started";
     }
 
-    if (llama_state_set_data(m_instance.ctx(), state.data(), state.size()) != state.size()) {
+    if (llama_state_set_data(m_ctx, state.data(), state.size()) != state.size()) {
         throw_ex{} << "Failed to set state";
     }
     return true;
 }
 
 void Session::doDecode(std::span<const Token> tokens, Source src) {
-    // first try to expand the context if needed
-    const auto gaFactor = m_params.gaFactor;
-    auto lctx = m_instance.ctx();
-    const auto ctxLen = llama_n_ctx(lctx);
-    auto& sampler = m_instance.sampler();
-
     // Ensure the input doesn't exceed the context size by truncating embd if necessary.
     if (tokens.size() > m_state.maxTokens) {
         const auto skipped = tokens.size() - m_state.maxTokens;
@@ -166,6 +159,10 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
     }
 
     bool haveFullContextMitigation = false;
+    const auto gaFactor = m_params.gaFactor;
+    const auto ctxLen = llama_n_ctx(m_ctx);
+    auto& sampler = m_instance.sampler();
+
     if (gaFactor == 1) {
         // infinite text generation via context shifting
         // if we run out of context:
@@ -183,8 +180,8 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
             LLAMA_LOG(Debug, "Context is full. Swapping: past = ", m_state.numPast, ", numLeft: ", numLeft,
                 ", ctxLen: ", ctxLen, ", numKeep: ", m_state.numKeep, ", numDiscard: ", numDiscard);
 
-            llama_kv_cache_seq_rm(lctx, 0, m_state.numKeep, m_state.numKeep + numDiscard);
-            llama_kv_cache_seq_add(lctx, 0, m_state.numKeep + numDiscard, m_state.numPast, -numDiscard);
+            llama_kv_cache_seq_rm(m_ctx, 0, m_state.numKeep, m_state.numKeep + numDiscard);
+            llama_kv_cache_seq_add(m_ctx, 0, m_state.numKeep + numDiscard, m_state.numPast, -numDiscard);
 
             m_state.numPast -= numDiscard;
             haveFullContextMitigation = true;
@@ -201,9 +198,9 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
 
             LLAMA_LOG(Debug, "Group attention shift: ib = ", ib, ", bd = ", bd, ", dd = ", dd);
 
-            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex, m_state.numPast, ib * bd);
-            llama_kv_cache_seq_div(lctx, 0, m_state.gaIndex + ib * bd, m_state.gaIndex + ib * bd + gaWidth, gaFactor);
-            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex + ib * bd + gaWidth, m_state.numPast + ib * bd, dd);
+            llama_kv_cache_seq_add(m_ctx, 0, m_state.gaIndex, m_state.numPast, ib * bd);
+            llama_kv_cache_seq_div(m_ctx, 0, m_state.gaIndex + ib * bd, m_state.gaIndex + ib * bd + gaWidth, gaFactor);
+            llama_kv_cache_seq_add(m_ctx, 0, m_state.gaIndex + ib * bd + gaWidth, m_state.numPast + ib * bd, dd);
 
             m_state.numPast -= bd;
 
@@ -223,14 +220,14 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
     }
 
     // decode
-    const auto batchSize = llama_n_batch(lctx);
+    const auto batchSize = llama_n_batch(m_ctx);
 
     // decode with batches of batchSize
     while (!tokens.empty()) {
         auto batchTokens = tokens.size() > batchSize ? tokens.first(batchSize) : tokens;
         tokens = tokens.subspan(batchTokens.size());
         auto batch = makeInputBatch(batchTokens);
-        if (llama_decode(lctx, batch) != 0) {
+        if (llama_decode(m_ctx, batch) != 0) {
             throw_ex{} << "Failed to decode tokens";
         }
         m_state.numPast += uint32_t(batchTokens.size());
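
For orientation, a minimal sketch of what a call site for the revised constructor could look like. Only `Session(instance, ctx, params)` itself comes from the diff above; the `runWithSession` helper, the standalone context creation, and the `Session::InitParams` spelling are assumptions for illustration.

```cpp
#include <llama.h>
#include <utility>

// Hypothetical helper, not part of this commit: the caller creates the
// llama_context and hands it to the Session explicitly, which caches it
// as m_ctx for every subsequent llama_* call.
void runWithSession(Instance& instance, llama_model* model, Session::InitParams params) {
    llama_context_params cparams = llama_context_default_params();
    llama_context* ctx = llama_new_context_with_model(model, cparams);

    Session session(instance, ctx, std::move(params));
    // ... feed the initial prompt, pull tokens, etc. ...

    llama_free(ctx); // in this sketch the context's lifetime stays with the caller
}
```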