Commit adbfd75

refactor: change session's co-routine with stateful class, ref #17
1 parent 2d6b51b commit adbfd75

7 files changed: +400 -661 lines
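This commit replaces the coroutine implementation of Instance::newSession (removed in the Instance.cpp diff below) with a stateful Session class that gets its own translation unit (Session.cpp, added to the build below). From the caller's side, newSession now simply constructs and returns an object instead of a suspended coroutine that had to be resumed with Prompt/GetState/SetState operations. A minimal, hypothetical usage sketch follows; only newSession(Session::InitParams) is visible in this diff, so the Session member names below are assumptions, not the actual API:

    // Hypothetical caller-side sketch; the Session member functions named here
    // (setInitialPrompt, getToken) are assumptions for illustration only.
    #include "ac/llama/Instance.hpp"
    #include "ac/llama/Session.hpp"
    #include <span>

    void generate(ac::llama::Instance& instance, std::span<const ac::llama::Token> prompt) {
        // before: the returned Session wrapped a coroutine resumed with Prompt/SetState ops
        // after:  it is a plain object constructed as Session(instance, params)
        auto session = instance.newSession({}); // default Session::InitParams
        session.setInitialPrompt(prompt);       // assumed API
        while (true) {
            auto token = session.getToken();    // assumed API
            if (token == ac::llama::Token_Invalid) break; // end of generation
            // ... consume the token ...
        }
    }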

code/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ target_sources(ac-llama PRIVATE
     ac/llama/Instance.hpp
     ac/llama/Instance.cpp
     ac/llama/Session.hpp
+    ac/llama/Session.cpp
     ac/llama/AntipromptManager.hpp
     ac/llama/AntipromptManager.cpp
     ac/llama/IncrementalStringFinder.hpp

code/ac/llama/Instance.cpp

Lines changed: 2 additions & 223 deletions
@@ -116,230 +116,9 @@ void Instance::warmup() {
     llama_perf_context_reset(lctx);
 }

-Session Instance::newSession(const SessionParams params) {
+Session Instance::newSession(const Session::InitParams params) {
     // not a real await as we return suspend_always initially
-    auto op = co_await Session::Prompt{};
-
-    if (m_hasActiveSession) {
-        throw_ex{} << "Instance already has an active session";
-    }
-
-    if (op.type != Session::SessionOpData::OpType::Prompt && op.type != Session::SessionOpData::OpType::SetState) {
-        throw_ex{} << "Invalid initial session operation type";
-    }
-
-    m_hasActiveSession = true;
-    astl::sentry closeSessionSentry([this] { m_hasActiveSession = false; });
-
-    auto lctx = m_lctx.get();
-    auto& vocab = m_model.vocab();
-
-    llama_kv_cache_clear(lctx);
-    llama_synchronize(lctx);
-    llama_perf_context_reset(lctx);
-    m_sampler.reset();
-    m_sampler.perfReset();
-
-    std::vector<llama_token> sessionTokens;
-    const auto tokenBos = llama_token_bos(m_model.lmodel());
-    const auto ctxLen = llama_n_ctx(lctx);
-    const auto maxTokens = ctxLen - 4; // (#16)
-    auto numKeep = llama_get_kv_cache_token_count(lctx);
-
-    if (op.type == Session::SessionOpData::OpType::Prompt) {
-        Token initialToken; // used to reset the initial prompt to a single token
-        auto& initialPrompt = op.pendingPrompt;
-        numKeep = std::min(uint32_t(initialPrompt.size()), maxTokens); // number of tokens to keep in the context in case we overflow
-
-        if (initialPrompt.empty()) {
-            initialToken = tokenBos;
-            initialPrompt = {&initialToken, 1};
-        }
-
-        if (initialPrompt.empty()) {
-            throw_ex{} << "Empty initial prompt";
-        }
-
-        if (initialPrompt.size() > maxTokens) {
-            throw_ex{} << "Initial prompt too long. Got " << initialPrompt.size() << " tokens, max: " << ctxLen - 4;
-        }
-
-        if (params.gaFactor != 1) {
-            const uint32_t gaFactor = params.gaFactor;
-            const uint32_t gaWidth = params.gaWidth;
-            if (gaWidth % gaFactor != 0) {
-                throw_ex{} << "Group-attention width " << gaWidth << " must be a multiple of group-attention factor " << gaFactor;
-            }
-            LLAMA_LOG(Info, "self-extend: train = ", m_model.trainCtxLength(), ", gaFactor = ", gaFactor, ", gaWidth = ", gaWidth);
-        }
-
-        if (m_model.hasEncoder()) {
-            auto batch = makeInputBatch(initialPrompt);
-            auto res = llama_encode(lctx, batch);
-            if (res != 0) {
-                throw_ex{} << "Failed to encode input";
-            }
-            initialToken = vocab.decoderStartToken();
-            initialPrompt = {&initialToken, 1};
-        }
-    } else {
-        if (llama_state_set_data(lctx, op.state.data(), op.state.size()) != op.state.size()) {
-            throw_ex{} << "Failed to set state";
-        }
-    }
-
-    // group attention state
-    uint32_t gaIndex = 0; // number of grouped KV tokens (only used if params.gaFactor > 1)
-    uint32_t numPast = 0; // number of tokens in the context (that's prompts + generated)
-
-    enum class Source {
-        InitialPrompt,
-        InteractivePrompt,
-        Generated
-    };
-
-    auto doDecode = [&](std::span<const Token> tokens, Source src) {
-        // first try to expand the context if needed
-        const auto gaFactor = params.gaFactor;
-
-        // Ensure the input doesn't exceed the context size by truncating embd if necessary.
-        if (tokens.size() > maxTokens) {
-            const auto skipped = tokens.size() - maxTokens;
-            tokens = tokens.first(maxTokens);
-            LLAMA_LOG(Warning, "Input too long. Skipping ", skipped, " tokens");
-        }
-
-        bool haveFullContextMitigation = false;
-        if (gaFactor == 1) {
-            // infinite text generation via context shifting
-            // if we run out of context:
-            // - take the n_keep first tokens from the original prompt (via numPast)
-            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            const auto num = numPast + tokens.size();
-            if (num >= ctxLen) {
-                if (!params.infiniteContext) {
-                    throw_ex{} << "context limit of " << ctxLen << " reached";
-                }
-
-                const auto numLeft = numPast - numKeep;
-                const int numDiscard = numLeft / 2; // somewhat arbitrary
-
-                LLAMA_LOG(Debug, "Context is full. Swapping: past = ", numPast, ", numLeft: ", numLeft,
-                    ", ctxLen: ", ctxLen, ", numKeep: ", numKeep, ", numDiscard: ", numDiscard);
-
-                llama_kv_cache_seq_rm(lctx, 0, numKeep, numKeep + numDiscard);
-                llama_kv_cache_seq_add(lctx, 0, numKeep + numDiscard, numPast, -numDiscard);
-
-                numPast -= numDiscard;
-                haveFullContextMitigation = true;
-            }
-        }
-        else {
-            const uint32_t gaWidth = params.gaWidth;
-
-            while (numPast >= gaIndex + gaWidth) {
-                // context extension via Self-Extend
-                const int ib = (gaFactor * gaIndex) / gaWidth;
-                const int bd = (gaWidth / gaFactor) * (gaFactor - 1);
-                const int dd = (gaWidth / gaFactor) - ib * bd - gaWidth;
-
-                LLAMA_LOG(Debug, "Group attention shift: ib = ", ib, ", bd = ", bd, ", dd = ", dd);
-
-                llama_kv_cache_seq_add(lctx, 0, gaIndex, numPast, ib * bd);
-                llama_kv_cache_seq_div(lctx, 0, gaIndex + ib * bd, gaIndex + ib * bd + gaWidth, gaFactor);
-                llama_kv_cache_seq_add(lctx, 0, gaIndex + ib * bd + gaWidth, numPast + ib * bd, dd);
-
-                numPast -= bd;
-
-                gaIndex += gaWidth / gaFactor;
-                haveFullContextMitigation = true;
-            }
-        }
-
-        if (haveFullContextMitigation) {
-            LLAMA_LOG(Info, "Context full mitigation performed: past = ", numPast, ", tokens = ", tokens.size());
-        }
-
-        // add to sampler
-        for (auto t : tokens) {
-            // only apply grammar for generated content
-            m_sampler.accept(t, src == Source::Generated);
-        }
-
-        // decode
-        const auto batchSize = llama_n_batch(lctx);
-
-        // decode with batches of batchSize
-        while (!tokens.empty()) {
-            auto batchTokens = tokens.size() > batchSize ? tokens.first(batchSize) : tokens;
-            tokens = tokens.subspan(batchTokens.size());
-            auto batch = makeInputBatch(batchTokens);
-            if (llama_decode(lctx, batch) != 0) {
-                throw_ex{} << "Failed to decode tokens";
-            }
-            numPast += uint32_t(batchTokens.size());
-        }
-    };
-
-    if (op.type == Session::SessionOpData::OpType::Prompt) {
-        doDecode(op.pendingPrompt, Source::InitialPrompt);
-
-        co_await Session::StartGeneration{}; // suspend pre generation
-    } else {
-        // set the state
-        co_yield true;
-    }
-
-    while (true) {
-        auto currOp = co_await Session::Prompt{};
-
-        if (currOp.type == Session::SessionOpData::OpType::GetState) {
-            // get the state
-            const auto size = llama_state_get_size(m_lctx.get());
-            std::vector<uint8_t> state(size);
-            if (llama_state_get_data(m_lctx.get(), state.data(), size) != size) {
-                throw_ex{} << "Failed to get state";
-            }
-            co_yield state;
-            continue;
-        } else if (currOp.type == Session::SessionOpData::OpType::SetState) {
-            auto& state = currOp.state;
-            if (llama_state_set_data(m_lctx.get(), state.data(), state.size()) != state.size()) {
-                throw_ex{} << "Failed to set state";
-            }
-            co_yield true;
-            continue;
-        } else if (currOp.type == Session::SessionOpData::OpType::Prompt) {
-            auto& prompt = currOp.pendingPrompt;
-            if (!prompt.empty()) {
-
-                // reset sampling and don't allow previous inputs to affect the generation
-                m_sampler.reset();
-
-                if (m_model.prefixInputsWithBos()) {
-                    // add bos token to the prompt
-                    doDecode({&tokenBos, 1}, Source::InteractivePrompt);
-                }
-
-                doDecode(prompt, Source::InteractivePrompt);
-            }
-
-            auto token = m_sampler.sample(lctx);
-            sessionTokens.push_back(token);
-            if (vocab.isEog(token)) {
-                co_yield Token_Invalid;
-                // don't decode eog tokens in case the the interaction is continued
-            }
-            else {
-                // first yield, then decode, thus we don't decode if the session is aborted
-                co_yield token;
-                doDecode({&token, 1}, Source::Generated);
-            }
-        } else {
-            LLAMA_LOG(Error, "Unrecognized session operation type");
-        }
-
-    }
+    return Session(*this, params);
 }

 } // namespace ac::llama
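After this change newSession only constructs and returns the Session object; everything the coroutine used to do here (clearing the KV cache, decoding the initial prompt, context shifting / self-extend, the GetState/SetState operations, and token sampling) presumably moves into the new Session class. A hedged sketch of where that logic might land, consistent with return Session(*this, params); only the Instance& plus InitParams construction is implied by this diff, the member layout below is an assumption:

    // Hypothetical sketch, not the actual Session.cpp from this commit; it only
    // illustrates how the removed coroutine locals could map onto member state.
    class Session {
    public:
        struct InitParams { /* gaFactor, gaWidth, infiniteContext; see the Instance.hpp diff below */ };

        Session(Instance& instance, InitParams params)
            : m_instance(instance)
            , m_params(params)
        {
            // the old coroutine prologue (llama_kv_cache_clear, llama_synchronize,
            // sampler reset, initial-prompt decode) would move here or into a
            // dedicated initial-prompt step
        }

    private:
        Instance& m_instance;   // provides ctx() and sampler(), exposed in Instance.hpp below
        InitParams m_params;
        uint32_t m_numKeep = 0; // former coroutine locals become member state
        uint32_t m_numPast = 0;
        uint32_t m_gaIndex = 0;
    };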

code/ac/llama/Instance.hpp

Lines changed: 4 additions & 11 deletions
@@ -4,6 +4,7 @@
 #pragma once
 #include "export.h"
 #include "Sampler.hpp"
+#include "Session.hpp"
 #include <astl/mem_ext.hpp>

 struct llama_context;
@@ -32,20 +33,12 @@ class AC_LLAMA_EXPORT Instance {
     // do an empty model run to load model data in cache
     void warmup();

-    struct SessionParams {
-        uint32_t gaFactor = 1; // group-attention factor
-        uint32_t gaWidth = 512; // group-attention width
-
-        // if true, the inference tries to extend the context by truncating previous tokens
-        // only used if gaFactor == 1
-        bool infiniteContext = true;
-    };
-
     // only one session per instance can be active at a time
-    Session newSession(const SessionParams params);
+    Session newSession(const Session::InitParams params);

     const Model& model() const noexcept { return m_model; }
-    const Sampler& sampler() const noexcept { return m_sampler; }
+    llama_context* ctx() const noexcept { return m_lctx.get(); }
+    Sampler& sampler() noexcept { return m_sampler; }

 private:
     Model& m_model;
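The nested Instance::SessionParams struct disappears here; since newSession now takes Session::InitParams, its fields presumably move there, although Session.hpp itself is not part of this excerpt. The new ctx() and sampler() accessors give the Session class access to the llama_context and the (now mutable) sampler it drives itself. An assumed sketch of the relocated parameters, with names and defaults copied from the removed struct above:

    // Assumed contents of Session::InitParams (not shown in this excerpt); the
    // field names and defaults come from the removed Instance::SessionParams.
    struct InitParams {
        uint32_t gaFactor = 1;  // group-attention factor
        uint32_t gaWidth = 512; // group-attention width

        // if true, the inference tries to extend the context by truncating previous tokens
        // only used if gaFactor == 1
        bool infiniteContext = true;
    };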
