@@ -116,230 +116,9 @@ void Instance::warmup() {
     llama_perf_context_reset(lctx);
 }
 
-Session Instance::newSession(const SessionParams params) {
+Session Instance::newSession(const Session::InitParams params) {
     // not a real await as we return suspend_always initially
-    auto op = co_await Session::Prompt{};
-
-    if (m_hasActiveSession) {
-        throw_ex{} << "Instance already has an active session";
-    }
-
-    if (op.type != Session::SessionOpData::OpType::Prompt && op.type != Session::SessionOpData::OpType::SetState) {
-        throw_ex{} << "Invalid initial session operation type";
-    }
-
-    m_hasActiveSession = true;
-    astl::sentry closeSessionSentry([this] { m_hasActiveSession = false; });
-
-    auto lctx = m_lctx.get();
-    auto& vocab = m_model.vocab();
-
-    llama_kv_cache_clear(lctx);
-    llama_synchronize(lctx);
-    llama_perf_context_reset(lctx);
-    m_sampler.reset();
-    m_sampler.perfReset();
-
-    std::vector<llama_token> sessionTokens;
-    const auto tokenBos = llama_token_bos(m_model.lmodel());
-    const auto ctxLen = llama_n_ctx(lctx);
-    const auto maxTokens = ctxLen - 4; // (#16)
-    auto numKeep = llama_get_kv_cache_token_count(lctx);
-
-    if (op.type == Session::SessionOpData::OpType::Prompt) {
-        Token initialToken; // used to reset the initial prompt to a single token
-        auto& initialPrompt = op.pendingPrompt;
-        numKeep = std::min(uint32_t(initialPrompt.size()), maxTokens); // number of tokens to keep in the context in case we overflow
-
-        if (initialPrompt.empty()) {
-            initialToken = tokenBos;
-            initialPrompt = {&initialToken, 1};
-        }
-
-        if (initialPrompt.empty()) {
-            throw_ex{} << "Empty initial prompt";
-        }
-
-        if (initialPrompt.size() > maxTokens) {
-            throw_ex{} << "Initial prompt too long. Got " << initialPrompt.size() << " tokens, max: " << ctxLen - 4;
-        }
-
-        if (params.gaFactor != 1) {
-            const uint32_t gaFactor = params.gaFactor;
-            const uint32_t gaWidth = params.gaWidth;
-            if (gaWidth % gaFactor != 0) {
-                throw_ex{} << "Group-attention width " << gaWidth << " must be a multiple of group-attention factor " << gaFactor;
-            }
-            LLAMA_LOG(Info, "self-extend: train = ", m_model.trainCtxLength(), ", gaFactor = ", gaFactor, ", gaWidth = ", gaWidth);
-        }
-
-        if (m_model.hasEncoder()) {
-            auto batch = makeInputBatch(initialPrompt);
-            auto res = llama_encode(lctx, batch);
-            if (res != 0) {
-                throw_ex{} << "Failed to encode input";
-            }
-            initialToken = vocab.decoderStartToken();
-            initialPrompt = {&initialToken, 1};
-        }
-    } else {
-        if (llama_state_set_data(lctx, op.state.data(), op.state.size()) != op.state.size()) {
-            throw_ex{} << "Failed to set state";
-        }
-    }
-
-    // group attention state
-    uint32_t gaIndex = 0; // number of grouped KV tokens (only used if params.gaFactor > 1)
-    uint32_t numPast = 0; // number of tokens in the context (that's prompts + generated)
-
-    enum class Source {
-        InitialPrompt,
-        InteractivePrompt,
-        Generated
-    };
-
-    auto doDecode = [&](std::span<const Token> tokens, Source src) {
-        // first try to expand the context if needed
-        const auto gaFactor = params.gaFactor;
-
-        // Ensure the input doesn't exceed the context size by truncating embd if necessary.
-        if (tokens.size() > maxTokens) {
-            const auto skipped = tokens.size() - maxTokens;
-            tokens = tokens.first(maxTokens);
-            LLAMA_LOG(Warning, "Input too long. Skipping ", skipped, " tokens");
-        }
-
-        bool haveFullContextMitigation = false;
-        if (gaFactor == 1) {
-            // infinite text generation via context shifting
-            // if we run out of context:
-            // - take the n_keep first tokens from the original prompt (via numPast)
-            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            const auto num = numPast + tokens.size();
-            if (num >= ctxLen) {
-                if (!params.infiniteContext) {
-                    throw_ex{} << "context limit of " << ctxLen << " reached";
-                }
-
-                const auto numLeft = numPast - numKeep;
-                const int numDiscard = numLeft / 2; // somewhat arbitrary
-
-                LLAMA_LOG(Debug, "Context is full. Swapping: past = ", numPast, ", numLeft: ", numLeft,
-                    ", ctxLen: ", ctxLen, ", numKeep: ", numKeep, ", numDiscard: ", numDiscard);
-
-                llama_kv_cache_seq_rm(lctx, 0, numKeep, numKeep + numDiscard);
-                llama_kv_cache_seq_add(lctx, 0, numKeep + numDiscard, numPast, -numDiscard);
-
-                numPast -= numDiscard;
-                haveFullContextMitigation = true;
-            }
-        }
-        else {
-            const uint32_t gaWidth = params.gaWidth;
-
-            while (numPast >= gaIndex + gaWidth) {
-                // context extension via Self-Extend
-                const int ib = (gaFactor * gaIndex) / gaWidth;
-                const int bd = (gaWidth / gaFactor) * (gaFactor - 1);
-                const int dd = (gaWidth / gaFactor) - ib * bd - gaWidth;
-
-                LLAMA_LOG(Debug, "Group attention shift: ib = ", ib, ", bd = ", bd, ", dd = ", dd);
-
-                llama_kv_cache_seq_add(lctx, 0, gaIndex, numPast, ib * bd);
-                llama_kv_cache_seq_div(lctx, 0, gaIndex + ib * bd, gaIndex + ib * bd + gaWidth, gaFactor);
-                llama_kv_cache_seq_add(lctx, 0, gaIndex + ib * bd + gaWidth, numPast + ib * bd, dd);
-
-                numPast -= bd;
-
-                gaIndex += gaWidth / gaFactor;
-                haveFullContextMitigation = true;
-            }
-        }
-
-        if (haveFullContextMitigation) {
-            LLAMA_LOG(Info, "Context full mitigation performed: past = ", numPast, ", tokens = ", tokens.size());
-        }
-
-        // add to sampler
-        for (auto t : tokens) {
-            // only apply grammar for generated content
-            m_sampler.accept(t, src == Source::Generated);
-        }
-
-        // decode
-        const auto batchSize = llama_n_batch(lctx);
-
-        // decode with batches of batchSize
-        while (!tokens.empty()) {
-            auto batchTokens = tokens.size() > batchSize ? tokens.first(batchSize) : tokens;
-            tokens = tokens.subspan(batchTokens.size());
-            auto batch = makeInputBatch(batchTokens);
-            if (llama_decode(lctx, batch) != 0) {
-                throw_ex{} << "Failed to decode tokens";
-            }
-            numPast += uint32_t(batchTokens.size());
-        }
-    };
-
-    if (op.type == Session::SessionOpData::OpType::Prompt) {
-        doDecode(op.pendingPrompt, Source::InitialPrompt);
-
-        co_await Session::StartGeneration{}; // suspend pre generation
-    } else {
-        // set the state
-        co_yield true;
-    }
-
-    while (true) {
-        auto currOp = co_await Session::Prompt{};
-
-        if (currOp.type == Session::SessionOpData::OpType::GetState) {
-            // get the state
-            const auto size = llama_state_get_size(m_lctx.get());
-            std::vector<uint8_t> state(size);
-            if (llama_state_get_data(m_lctx.get(), state.data(), size) != size) {
-                throw_ex{} << "Failed to get state";
-            }
-            co_yield state;
-            continue;
-        } else if (currOp.type == Session::SessionOpData::OpType::SetState) {
-            auto& state = currOp.state;
-            if (llama_state_set_data(m_lctx.get(), state.data(), state.size()) != state.size()) {
-                throw_ex{} << "Failed to set state";
-            }
-            co_yield true;
-            continue;
-        } else if (currOp.type == Session::SessionOpData::OpType::Prompt) {
-            auto& prompt = currOp.pendingPrompt;
-            if (!prompt.empty()) {
-
-                // reset sampling and don't allow previous inputs to affect the generation
-                m_sampler.reset();
-
-                if (m_model.prefixInputsWithBos()) {
-                    // add bos token to the prompt
-                    doDecode({&tokenBos, 1}, Source::InteractivePrompt);
-                }
-
-                doDecode(prompt, Source::InteractivePrompt);
-            }
-
-            auto token = m_sampler.sample(lctx);
-            sessionTokens.push_back(token);
-            if (vocab.isEog(token)) {
-                co_yield Token_Invalid;
-                // don't decode eog tokens in case the interaction is continued
-            }
-            else {
-                // first yield, then decode, thus we don't decode if the session is aborted
-                co_yield token;
-                doDecode({&token, 1}, Source::Generated);
-            }
-        } else {
-            LLAMA_LOG(Error, "Unrecognized session operation type");
-        }
-
-    }
+    return Session(*this, params);
 }
 
 } // namespace ac::llama
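
For reference, the KV-cache "context shift" that the removed body performed when params.infiniteContext was set (and which this change presumably moves into the Session type) boils down to the arithmetic below. This is a minimal standalone C++ sketch with invented numbers, not code from this repository; the llama.cpp calls it refers to appear only in comments.

// Standalone illustration of the context-shift arithmetic from the removed code above.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t ctxLen   = 4096; // assumed context size
    const uint32_t numKeep  = 128;  // tokens pinned from the initial prompt
    uint32_t numPast        = 4090; // tokens currently in the KV cache
    const uint32_t incoming = 32;   // tokens about to be decoded

    if (numPast + incoming >= ctxLen) {
        const uint32_t numLeft    = numPast - numKeep;
        const uint32_t numDiscard = numLeft / 2; // "somewhat arbitrary", as the original comment says

        // the removed code drops the window [numKeep, numKeep + numDiscard) via
        // llama_kv_cache_seq_rm and then slides the remaining tail back over the
        // gap via llama_kv_cache_seq_add with a delta of -numDiscard
        numPast -= numDiscard;

        std::printf("kept %u, discarded %u, numPast is now %u of %u\n",
                    (unsigned)numKeep, (unsigned)numDiscard, (unsigned)numPast, (unsigned)ctxLen);
    }
    return 0;
}

Keeping the first numKeep prompt tokens and discarding half of what remains preserves the original prompt while roughly halving how often the shift has to run.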