
Commit 324b9b5

fix: hide context from public API by passing it directly to the session, ref #17
1 parent: c73ae32

8 files changed: +54 −46 lines
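In short, the commit stops exposing the raw llama_context through the Instance API and instead passes it to the Session at construction time. The Instance now owns its single session: startSession() creates it on first use and returns a reference, and stopSession() releases it. A minimal caller-side sketch of the reworked API (names are taken from the diffs below; the prompt string and token loop are illustrative only):

    // before (removed): auto session = instance.newSession({});   // Session returned by value
    // after: the Instance keeps the Session and the llama_context stays private
    auto& session = instance.startSession({});                     // creates the session on first call
    session.setInitialPrompt(model.vocab().tokenize("Hello", true, true));
    for (int i = 0; i < 10; ++i) {
        auto token = session.getToken();                            // sample the next token
        // ... detokenize / print ...
    }
    instance.stopSession();                                         // release so a new session can be started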

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@ class ChatSession {
     using Interface = ac::local::schema::LlamaCppInterface;
 
     ChatSession(llama::Instance& instance, Interface::OpChatBegin::Params& params)
-        : m_session(instance.newSession({}))
+        : m_session(instance.startSession({}))
         , m_vocab(instance.model().vocab())
     {
         m_promptTokens = instance.model().vocab().tokenize(params.setup.value(), true, true);
@@ -149,7 +149,7 @@ class LlamaInstance final : public Instance {
         auto& prompt = params.prompt.value();
         const auto maxTokens = params.maxTokens.value();
 
-        auto s = m_instance.newSession({});
+        auto s = m_instance.startSession({});
 
         auto promptTokens = m_instance.model().vocab().tokenize(prompt, true, true);
         s.setInitialPrompt(promptTokens);

code/ac/llama/Instance.cpp

Lines changed: 6 additions & 3 deletions

@@ -116,9 +116,12 @@ void Instance::warmup() {
     llama_perf_context_reset(lctx);
 }
 
-Session Instance::newSession(const Session::InitParams params) {
-    // not a real await as we return suspend_always initially
-    return Session(*this, params);
+Session& Instance::startSession(const Session::InitParams params) {
+    if (!m_session) {
+        m_session.reset(new Session(*this, m_lctx.get(), params));
+    }
+
+    return *m_session;
 }
 
 } // namespace ac::llama
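One consequence of the lazy creation above, shown as a small sketch (the behaviour follows directly from the if (!m_session) guard; the assert is only for illustration and needs <cassert>): repeated startSession() calls return the same object, and the InitParams of later calls are ignored until stopSession() resets the pointer.

    auto& a = instance.startSession({});
    auto& b = instance.startSession({});   // no new session is created; `b` refers to the same object as `a`
    assert(&a == &b);
    instance.stopSession();                // destroys the session
    auto& c = instance.startSession({});   // a fresh session, constructed with the newly passed params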

code/ac/llama/Instance.hpp

Lines changed: 3 additions & 4 deletions

@@ -34,18 +34,17 @@ class AC_LLAMA_EXPORT Instance {
     void warmup();
 
     // only one session per instance can be active at a time
-    Session newSession(const Session::InitParams params);
+    Session& startSession(const Session::InitParams params);
+    void stopSession() noexcept { m_session.reset(); }
 
     const Model& model() const noexcept { return m_model; }
-    llama_context* ctx() const noexcept { return m_lctx.get(); }
     Sampler& sampler() noexcept { return m_sampler; }
 
 private:
     Model& m_model;
     Sampler m_sampler;
     astl::c_unique_ptr<llama_context> m_lctx;
-
-    bool m_hasActiveSession = false;
+    std::unique_ptr<Session> m_session;
 };
 
 } // namespace ac::llama
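For reference, the public surface of Instance after this hunk, abridged from the header above (other members and bodies omitted): the llama_context accessor is gone, so callers can only reach the context indirectly through a Session.

    class Instance {
    public:
        void warmup();

        // only one session per instance can be active at a time
        Session& startSession(const Session::InitParams params);
        void stopSession() noexcept;

        const Model& model() const noexcept;
        Sampler& sampler() noexcept;
        // llama_context* ctx() const noexcept;  // removed by this commit
    };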

code/ac/llama/Session.cpp

Lines changed: 23 additions & 26 deletions

@@ -21,20 +21,20 @@ llama_batch makeInputBatch(std::span<const Token> tokens) {
     }
 }
 
-Session::Session(Instance& instance, InitParams params)
+Session::Session(Instance& instance, llama_context* ctx, InitParams params)
     : m_instance(instance)
+    , m_ctx(ctx)
     , m_params(std::move(params))
 {
-    auto lctx = m_instance.ctx();
     auto& sampler = m_instance.sampler();
 
-    llama_kv_cache_clear(lctx);
-    llama_synchronize(lctx);
-    llama_perf_context_reset(lctx);
+    llama_kv_cache_clear(m_ctx);
+    llama_synchronize(m_ctx);
+    llama_perf_context_reset(m_ctx);
     sampler.reset();
     sampler.perfReset();
 
-    const auto ctxLen = llama_n_ctx(lctx);
+    const auto ctxLen = llama_n_ctx(m_ctx);
     m_state.maxTokens = ctxLen - 4; // (#16)
 }
 
@@ -45,8 +45,7 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
 
     Token initialToken; // used to reset the initial prompt to a single token
 
-    auto lctx = m_instance.ctx();
-    const auto ctxLen = llama_n_ctx(lctx);
+    const auto ctxLen = llama_n_ctx(m_ctx);
     const auto tokenBos = llama_token_bos(m_instance.model().lmodel());
     m_state.numKeep = std::min(uint32_t(initialPrompt.size()), m_state.maxTokens); // number of tokens to keep in the context in case we overflow
 
@@ -70,7 +69,7 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
 
     if (m_instance.model().hasEncoder()) {
         auto batch = makeInputBatch(initialPrompt);
-        auto res = llama_encode(lctx, batch);
+        auto res = llama_encode(m_ctx, batch);
         if (res != 0) {
             throw_ex{} << "Failed to encode input";
         }
@@ -117,7 +116,7 @@ Token Session::getToken() {
     auto& sampler = m_instance.sampler();
     auto& vocab = m_instance.model().vocab();
 
-    m_state.m_currToken = sampler.sample(m_instance.ctx());
+    m_state.m_currToken = sampler.sample(m_ctx);
 
     if (vocab.isEog(m_state.m_currToken)) {
         // don't decode eog tokens in case the the interaction is continued
@@ -132,9 +131,9 @@ std::vector<uint8_t> Session::getState() {
         throw_ex{} << "Session hasn't started yet";
     }
 
-    const auto size = llama_state_get_size(m_instance.ctx());
+    const auto size = llama_state_get_size(m_ctx);
     std::vector<uint8_t> state(size);
-    if (llama_state_get_data(m_instance.ctx(), state.data(), size) != size) {
+    if (llama_state_get_data(m_ctx, state.data(), size) != size) {
         throw_ex{} << "Failed to get state";
     }
     return state;
@@ -145,19 +144,13 @@ bool Session::setState(std::span<uint8_t> state) {
         throw_ex{} << "Session already started";
     }
 
-    if (llama_state_set_data(m_instance.ctx(), state.data(), state.size()) != state.size()) {
+    if (llama_state_set_data(m_ctx, state.data(), state.size()) != state.size()) {
         throw_ex{} << "Failed to set state";
     }
     return true;
 }
 
 void Session::doDecode(std::span<const Token> tokens, Source src) {
-    // first try to expand the context if needed
-    const auto gaFactor = m_params.gaFactor;
-    auto lctx = m_instance.ctx();
-    const auto ctxLen = llama_n_ctx(lctx);
-    auto& sampler = m_instance.sampler();
-
     // Ensure the input doesn't exceed the context size by truncating embd if necessary.
     if (tokens.size() > m_state.maxTokens) {
         const auto skipped = tokens.size() - m_state.maxTokens;
@@ -166,6 +159,10 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
     }
 
     bool haveFullContextMitigation = false;
+    const auto gaFactor = m_params.gaFactor;
+    const auto ctxLen = llama_n_ctx(m_ctx);
+    auto& sampler = m_instance.sampler();
+
     if (gaFactor == 1) {
         // infinite text generation via context shifting
         // if we run out of context:
@@ -183,8 +180,8 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
             LLAMA_LOG(Debug, "Context is full. Swapping: past = ", m_state.numPast, ", numLeft: ", numLeft,
                 ", ctxLen: ", ctxLen, ", numKeep: ", m_state.numKeep, ", numDiscard: ", numDiscard);
 
-            llama_kv_cache_seq_rm(lctx, 0, m_state.numKeep, m_state.numKeep + numDiscard);
-            llama_kv_cache_seq_add(lctx, 0, m_state.numKeep + numDiscard, m_state.numPast, -numDiscard);
+            llama_kv_cache_seq_rm(m_ctx, 0, m_state.numKeep, m_state.numKeep + numDiscard);
+            llama_kv_cache_seq_add(m_ctx, 0, m_state.numKeep + numDiscard, m_state.numPast, -numDiscard);
 
             m_state.numPast -= numDiscard;
             haveFullContextMitigation = true;
@@ -201,9 +198,9 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
 
             LLAMA_LOG(Debug, "Group attention shift: ib = ", ib, ", bd = ", bd, ", dd = ", dd);
 
-            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex, m_state.numPast, ib * bd);
-            llama_kv_cache_seq_div(lctx, 0, m_state.gaIndex + ib * bd, m_state.gaIndex + ib * bd + gaWidth, gaFactor);
-            llama_kv_cache_seq_add(lctx, 0, m_state.gaIndex + ib * bd + gaWidth, m_state.numPast + ib * bd, dd);
+            llama_kv_cache_seq_add(m_ctx, 0, m_state.gaIndex, m_state.numPast, ib * bd);
+            llama_kv_cache_seq_div(m_ctx, 0, m_state.gaIndex + ib * bd, m_state.gaIndex + ib * bd + gaWidth, gaFactor);
+            llama_kv_cache_seq_add(m_ctx, 0, m_state.gaIndex + ib * bd + gaWidth, m_state.numPast + ib * bd, dd);
 
             m_state.numPast -= bd;
 
@@ -223,14 +220,14 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
     }
 
     // decode
-    const auto batchSize = llama_n_batch(lctx);
+    const auto batchSize = llama_n_batch(m_ctx);
 
     // decode with batches of batchSize
    while (!tokens.empty()) {
        auto batchTokens = tokens.size() > batchSize ? tokens.first(batchSize) : tokens;
        tokens = tokens.subspan(batchTokens.size());
        auto batch = makeInputBatch(batchTokens);
-        if (llama_decode(lctx, batch) != 0) {
+        if (llama_decode(m_ctx, batch) != 0) {
            throw_ex{} << "Failed to decode tokens";
        }
        m_state.numPast += uint32_t(batchTokens.size());

code/ac/llama/Session.hpp

Lines changed: 4 additions & 1 deletion

@@ -10,6 +10,8 @@
 #include <vector>
 #include <cassert>
 
+struct llama_context;
+
 namespace ac::llama {
 class Instance;
 
@@ -23,7 +25,7 @@ class Session {
         // only used if gaFactor == 1
         bool infiniteContext = true;
     };
-    Session(Instance& instance, InitParams params);
+    Session(Instance& instance, llama_context* ctx, InitParams params);
 
     void setInitialPrompt(std::span<const Token> prompt);
 
@@ -56,6 +58,7 @@ class Session {
     };
 
     Instance& m_instance;
+    llama_context* m_ctx;
     InitParams m_params;
     State m_state;
 };
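The bare struct llama_context; forward declaration is enough here because Session only stores a pointer to the context; the public header does not need the full llama.cpp definition. This is a general C++ rule rather than anything specific to this repo, sketched with a hypothetical class:

    struct llama_context;           // incomplete type: no definition needed in this header

    class Example {                 // hypothetical, for illustration only
        llama_context* m_ctx;       // fine: a pointer to an incomplete type is allowed as a member
        // llama_context m_byValue; // would not compile: a by-value member needs the full definition
    };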

example/e-basic.cpp

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ int main() try {
     std::cout << "Prompt: " << prompt << "\n";
 
     // start session
-    auto session = instance.newSession({});
+    auto session = instance.startSession({});
     session.setInitialPrompt(model.vocab().tokenize(prompt, true, true));
 
     // generate and print 100 tokens
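The example never calls stopSession(); with the new ownership model that is acceptable, since the session is a member of the Instance and is destroyed with it. If the example were extended to run a second prompt on the same instance, a sketch would look like this (the second prompt string is illustrative):

    instance.stopSession();                  // discard the finished session
    auto& next = instance.startSession({});  // start a fresh one
    next.setInitialPrompt(model.vocab().tokenize("Another prompt", true, true));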

example/e-gui.cpp

Lines changed: 8 additions & 2 deletions

@@ -75,10 +75,11 @@ class UModel {
 class Session {
 public:
     Session(ac::llama::Instance& instance, std::string_view prompt, std::vector<std::string> antiprompts, ac::llama::Session::InitParams params)
-        : m_vocab(instance.model().vocab())
+        : m_instance(instance)
+        , m_vocab(instance.model().vocab())
         , m_params(std::move(params))
         , m_text(std::move(prompt))
-        , m_session(instance.newSession(m_params))
+        , m_session(instance.startSession(m_params))
     {
         m_promptTokens = m_vocab.tokenize(m_text, true, true);
         m_session.setInitialPrompt(m_promptTokens);
@@ -87,6 +88,10 @@ class UModel {
         }
     }
 
+    ~Session() {
+        m_instance.stopSession();
+    }
+
     const std::string& text() const { return m_text; }
     const ac::llama::Session::InitParams& params() const { return m_params; }
 
@@ -123,6 +128,7 @@ class UModel {
     }
 
 private:
+    ac::llama::Instance& m_instance;
     const ac::llama::Vocab& m_vocab;
     ac::llama::Session::InitParams m_params;
     std::vector<ac::llama::Token> m_promptTokens;
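The added destructor pairs startSession() in the member initializer list with stopSession() on destruction, so the GUI wrapper releases the instance's single session when it goes out of scope. The same idea as a standalone scope guard (a hypothetical helper, not part of this commit):

    struct SessionGuard {
        explicit SessionGuard(ac::llama::Instance& i, ac::llama::Session::InitParams p = {})
            : instance(i)
            , session(i.startSession(std::move(p)))
        {}
        ~SessionGuard() { instance.stopSession(); }

        SessionGuard(const SessionGuard&) = delete;            // non-copyable: it owns the active session
        SessionGuard& operator=(const SessionGuard&) = delete;

        ac::llama::Instance& instance;
        ac::llama::Session& session;
    };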

test/t-integration.cpp

Lines changed: 7 additions & 7 deletions

@@ -64,7 +64,7 @@ TEST_CASE("inference") {
     std::vector<ac::llama::Token> tokens;
 
     // choose a very, very suggestive prompt and hope that all architectures will agree
-    auto s = inst.newSession({});
+    auto s = inst.startSession({});
     tokens = model.vocab().tokenize("President George W.", true, true);
     s.setInitialPrompt(tokens);
     {
@@ -107,7 +107,7 @@ TEST_CASE("session states") {
     ac::llama::Instance inst(model, {});
     inst.addControlVector(ctrlVector);
     inst.warmup(); // should be safe
-    auto s = inst.newSession({});
+    auto s = inst.startSession({});
     std::vector<ac::llama::Token> tokens = model.vocab().tokenize("My car's fuel consumption is", true, true);
     s.setInitialPrompt(tokens);
     std::string text;
@@ -124,7 +124,7 @@ TEST_CASE("session states") {
     ac::llama::Instance inst(model, {});
     inst.addControlVector(ctrlVector);
     inst.warmup(); // should be safe
-    auto s = inst.newSession({});
+    auto s = inst.startSession({});
     std::vector<ac::llama::Token> tokens = model.vocab().tokenize("My car's fuel consumption is", true, true);
     s.setInitialPrompt(tokens);
     std::string text;
@@ -168,7 +168,7 @@ TEST_CASE("control_vector") {
     {
         // session 1
 
-        auto s = inst.newSession({});
+        auto s = inst.startSession({});
         auto tokens = model.vocab().tokenize(prompt, true, true);
         s.setInitialPrompt(tokens);
 
@@ -196,7 +196,7 @@ TEST_CASE("control_vector") {
     // test restoring the intial state
     // since the sampler is in the initial state we should get the same string
     {
-        auto s = inst.newSession({});
+        auto s = inst.startSession({});
         s.setState(initialState);
         std::string restoredStr;
 
@@ -218,7 +218,7 @@ TEST_CASE("control_vector") {
     //restores session 1
     std::string restoredStr;
     {
-        auto s = inst.newSession({});
+        auto s = inst.startSession({});
         s.setState(sessionMiddleState);
 
         for (size_t i = 0; i < nPredict / 2; i++) {
@@ -235,7 +235,7 @@ TEST_CASE("control_vector") {
     //restores session 2
     std::string restoredStr2;
    {
-        auto s = inst.newSession({});
+        auto s = inst.startSession({});
         s.setState(sessionMiddleState);
 
         for (size_t i = 0; i < nPredict / 2; i++) {
