
Commit c73ae32

fix: move whole state to a struct, ref #17
1 parent adbfd75 commit c73ae32

File tree

2 files changed: +67 −37 lines changed
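
The commit consolidates Session's scattered bookkeeping members (maxTokens, numKeep, gaIndex, numPast) into a single nested State struct and adds an explicit lifecycle phase that every public method now checks. A minimal caller-side sketch of the new guards, assuming an already-constructed ac::llama::Instance named instance and a token vector promptTokens (both hypothetical; the Initial-to-Generating transition is not visible in this diff and presumably happens inside setInitialPrompt):

    ac::llama::Session session(instance, {});

    // session.pushPrompt(promptTokens);    // would throw "Session hasn't started yet": phase is Initial
    // session.getToken();                  // likewise guarded

    session.setInitialPrompt(promptTokens); // only allowed in the Initial phase
    auto token = session.getToken();        // only allowed once the session is Generating
    // session.setInitialPrompt({});        // would now throw "Session already started"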

code/ac/llama/Session.cpp

Lines changed: 51 additions & 33 deletions
@@ -35,27 +35,27 @@ Session::Session(Instance& instance, InitParams params)
     sampler.perfReset();
 
     const auto ctxLen = llama_n_ctx(lctx);
-    maxTokens = ctxLen - 4; // (#16)
+    m_state.maxTokens = ctxLen - 4; // (#16)
 }
 
 void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
+    if (m_state.m_phase != State::Phase::Initial) {
+        throw_ex{} << "Session already started";
+    }
+
     Token initialToken; // used to reset the initial prompt to a single token
 
     auto lctx = m_instance.ctx();
     const auto ctxLen = llama_n_ctx(lctx);
     const auto tokenBos = llama_token_bos(m_instance.model().lmodel());
-    numKeep = std::min(uint32_t(initialPrompt.size()), maxTokens); // number of tokens to keep in the context in case we overflow
+    m_state.numKeep = std::min(uint32_t(initialPrompt.size()), m_state.maxTokens); // number of tokens to keep in the context in case we overflow
 
     if (initialPrompt.empty()) {
         initialToken = tokenBos;
         initialPrompt = {&initialToken, 1};
     }
 
-    if (initialPrompt.empty()) {
-        throw_ex{} << "Empty initial prompt";
-    }
-
-    if (initialPrompt.size() > maxTokens) {
+    if (initialPrompt.size() > m_state.maxTokens) {
         throw_ex{} << "Initial prompt too long. Got " << initialPrompt.size() << " tokens, max: " << ctxLen - 4;
     }
 
@@ -83,6 +83,10 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
 }
 
 void Session::pushPrompt(std::span<const Token> prompt) {
+    if (m_state.m_phase != State::Phase::Generating) {
+        throw_ex{} << "Session hasn't started yet";
+    }
+
     if (!prompt.empty()) {
         auto& sampler = m_instance.sampler();
         auto& model = m_instance.model();
@@ -101,23 +105,33 @@ void Session::pushPrompt(std::span<const Token> prompt) {
 }
 
 Token Session::getToken() {
+    if (m_state.m_phase != State::Phase::Generating) {
+        throw_ex{} << "Session hasn't started yet";
+    }
+
+    if (m_state.m_currToken != Token_Invalid) {
+        // first yield, then decode, thus we don't decode if the session is aborted
+        doDecode({&m_state.m_currToken, 1}, Source::Generated);
+    }
+
     auto& sampler = m_instance.sampler();
     auto& vocab = m_instance.model().vocab();
 
-    auto token = sampler.sample(m_instance.ctx());
+    m_state.m_currToken = sampler.sample(m_instance.ctx());
 
-    if (vocab.isEog(token)) {
-        return Token_Invalid;
+    if (vocab.isEog(m_state.m_currToken)) {
         // don't decode eog tokens in case the interaction is continued
+        m_state.m_currToken = Token_Invalid;
     }
 
-    // old-comment
-    // first yield, then decode, thus we don't decode if the session is aborted
-    doDecode({&token, 1}, Source::Generated);
-    return token;
+    return m_state.m_currToken;
 }
 
 std::vector<uint8_t> Session::getState() {
+    if (m_state.m_phase != State::Phase::Generating) {
+        throw_ex{} << "Session hasn't started yet";
+    }
+
     const auto size = llama_state_get_size(m_instance.ctx());
     std::vector<uint8_t> state(size);
     if (llama_state_get_data(m_instance.ctx(), state.data(), size) != size) {
@@ -127,6 +141,10 @@ std::vector<uint8_t> Session::getState() {
 }
 
 bool Session::setState(std::span<uint8_t> state) {
+    if (m_state.m_phase != State::Phase::Initial) {
+        throw_ex{} << "Session already started";
+    }
+
     if (llama_state_set_data(m_instance.ctx(), state.data(), state.size()) != state.size()) {
         throw_ex{} << "Failed to set state";
     }
@@ -141,9 +159,9 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
     auto& sampler = m_instance.sampler();
 
     // Ensure the input doesn't exceed the context size by truncating embd if necessary.
-    if (tokens.size() > maxTokens) {
-        const auto skipped = tokens.size() - maxTokens;
-        tokens = tokens.first(maxTokens);
+    if (tokens.size() > m_state.maxTokens) {
+        const auto skipped = tokens.size() - m_state.maxTokens;
+        tokens = tokens.first(m_state.maxTokens);
         LLAMA_LOG(Warning, "Input too long. Skipping ", skipped, " tokens");
     }
 
@@ -153,49 +171,49 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
         // if we run out of context:
         // - take the n_keep first tokens from the original prompt (via numPast)
         // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-        const auto num = numPast + tokens.size();
+        const auto num = m_state.numPast + tokens.size();
         if (num >= ctxLen) {
             if (!m_params.infiniteContext) {
                 throw_ex{} << "context limit of " << ctxLen << " reached";
             }
 
-            const auto numLeft = numPast - numKeep;
+            const auto numLeft = m_state.numPast - m_state.numKeep;
             const int numDiscard = numLeft / 2; // somewhat arbitrary
 
-            LLAMA_LOG(Debug, "Context is full. Swapping: past = ", numPast, ", numLeft: ", numLeft,
-                ", ctxLen: ", ctxLen, ", numKeep: ", numKeep, ", numDiscard: ", numDiscard);
+            LLAMA_LOG(Debug, "Context is full. Swapping: past = ", m_state.numPast, ", numLeft: ", numLeft,
+                ", ctxLen: ", ctxLen, ", numKeep: ", m_state.numKeep, ", numDiscard: ", numDiscard);
 
-            llama_kv_cache_seq_rm(lctx, 0, numKeep, numKeep + numDiscard);
-            llama_kv_cache_seq_add(lctx, 0, numKeep + numDiscard, numPast, -numDiscard);
+            llama_kv_cache_seq_rm(lctx, 0, m_state.numKeep, m_state.numKeep + numDiscard);
+            llama_kv_cache_seq_add(lctx, 0, m_state.numKeep + numDiscard, m_state.numPast, -numDiscard);
 
-            numPast -= numDiscard;
+            m_state.numPast -= numDiscard;
             haveFullContextMitigation = true;
         }
     }
     else {
         const uint32_t gaWidth = m_params.gaWidth;
 
-        while (numPast >= gaIndex + gaWidth) {
+        while (m_state.numPast >= m_state.gaIndex + gaWidth) {
             // context extension via Self-Extend
-            const int ib = (gaFactor * gaIndex) / gaWidth;
+            const int ib = (gaFactor * m_state.gaIndex) / gaWidth;
             const int bd = (gaWidth / gaFactor) * (gaFactor - 1);
             const int dd = (gaWidth / gaFactor) - ib * bd - gaWidth;
 
             LLAMA_LOG(Debug, "Group attention shift: ib = ", ib, ", bd = ", bd, ", dd = ", dd);
 
-            llama_kv_cache_seq_add(lctx, 0, gaIndex, numPast, ib * bd);
-            llama_kv_cache_seq_div(lctx, 0, gaIndex + ib * bd, gaIndex + ib * bd + gaWidth, gaFactor);
-            llama_kv_cache_seq_add(lctx, 0, gaIndex + ib * bd + gaWidth, numPast + ib * bd, dd);
+            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex, m_state.numPast, ib * bd);
+            llama_kv_cache_seq_div(lctx, 0, m_state.gaIndex + ib * bd, m_state.gaIndex + ib * bd + gaWidth, gaFactor);
+            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex + ib * bd + gaWidth, m_state.numPast + ib * bd, dd);
 
-            numPast -= bd;
+            m_state.numPast -= bd;
 
-            gaIndex += gaWidth / gaFactor;
+            m_state.gaIndex += gaWidth / gaFactor;
             haveFullContextMitigation = true;
         }
     }
 
     if (haveFullContextMitigation) {
-        LLAMA_LOG(Info, "Context full mitigation performed: past = ", numPast, ", tokens = ", tokens.size());
+        LLAMA_LOG(Info, "Context full mitigation performed: past = ", m_state.numPast, ", tokens = ", tokens.size());
     }
 
     // add to sampler
@@ -215,7 +233,7 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
         if (llama_decode(lctx, batch) != 0) {
            throw_ex{} << "Failed to decode tokens";
         }
-        numPast += uint32_t(batchTokens.size());
+        m_state.numPast += uint32_t(batchTokens.size());
     }
 
 }
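
The most consequential change in this file is the reordering inside getToken: the freshly sampled token is now cached in m_state.m_currToken and returned immediately, and its doDecode call is deferred to the start of the next getToken call, so a session abandoned after a yield never decodes a token nobody consumed ("first yield, then decode" in the in-diff comment). A self-contained sketch of that pattern in isolation (generic names, not the actual Session code):

    #include <functional>
    #include <optional>

    // "Yield now, finish the work on the next call" stepper.
    // next() returns a value immediately; the follow-up work for that value
    // (decoding, in Session's case) only runs when another value is requested.
    class LazyStepper {
    public:
        LazyStepper(std::function<int()> sample, std::function<void(int)> decode)
            : m_sample(std::move(sample)), m_decode(std::move(decode)) {}

        int next() {
            if (m_pending) {
                m_decode(*m_pending); // deferred work for the previously yielded value
            }
            m_pending = m_sample();
            return *m_pending;
        }

    private:
        std::function<int()> m_sample;
        std::function<void(int)> m_decode;
        std::optional<int> m_pending; // plays the role of m_state.m_currToken
    };

If the consumer stops after a next() call, m_decode never runs for that last value, mirroring how an aborted session skips the final decode.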

code/ac/llama/Session.hpp

Lines changed: 16 additions & 4 deletions
@@ -40,12 +40,24 @@ class Session {
 
     void doDecode(std::span<const Token> tokens, Source src);
 
+    struct State {
+        enum class Phase {
+            Initial,
+            Generating
+        };
+
+        Phase m_phase = Phase::Initial;
+        Token m_currToken = Token_Invalid;
+
+        unsigned maxTokens = 0;
+        unsigned numKeep = 0;
+        uint32_t gaIndex = 0; // number of grouped KV tokens (only used if params.gaFactor > 1)
+        uint32_t numPast = 0; // number of tokens in the context (that's prompts + generated)
+    };
+
     Instance& m_instance;
     InitParams m_params;
-    unsigned maxTokens = 0;
-    unsigned numKeep = 0;
-    uint32_t gaIndex = 0; // number of grouped KV tokens (only used if params.gaFactor > 1)
-    uint32_t numPast = 0; // number of tokens in the context (that's prompts + generated)
+    State m_state;
 };
 
 } // namespace ac::llama
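
Grouping the members into one aggregate also means the whole session state can be reset or snapshotted as a single value rather than member by member. A hedged sketch of the kind of helper this enables (hypothetical; not part of this commit):

    // Hypothetical convenience enabled by the State aggregate:
    void Session::reset() {
        m_state = State{}; // one assignment restores phase, current token, and all counters
    }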
