
Commit 18cafe8

feat: add tests for the session, ref #17
1 parent 324b9b5 commit 18cafe8

9 files changed: 108 additions, 178 deletions


ac-local-plugin/code/LocalLlama.cpp

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ namespace ac::local {
 
 namespace {
 class ChatSession {
-    llama::Session m_session;
+    llama::Session& m_session;
     const llama::Vocab& m_vocab;
     std::string m_userPrefix;
     std::string m_assistantPrefix;
@@ -149,7 +149,7 @@ class LlamaInstance final : public Instance {
     auto& prompt = params.prompt.value();
     const auto maxTokens = params.maxTokens.value();
 
-    auto s = m_instance.startSession({});
+    auto& s = m_instance.startSession({});
 
     auto promptTokens = m_instance.model().vocab().tokenize(prompt, true, true);
     s.setInitialPrompt(promptTokens);
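
Note: with this change ChatSession no longer owns a Session value; it aliases the session owned by llama::Instance. A reference member must be bound in the constructor's initializer list, and the Instance (and its session) must outlive the ChatSession. A minimal hypothetical sketch of the binding (ChatSession's constructor is not shown in this commit, so the signature is an assumption):

    // Hypothetical sketch, not part of the diff.
    class ChatSession {
        llama::Session& m_session; // owned by llama::Instance; must outlive ChatSession
    public:
        explicit ChatSession(llama::Session& session)
            : m_session(session) // references are bound once, in the init list
        {}
    };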

code/ac/llama/Instance.cpp

Lines changed: 3 additions & 2 deletions
@@ -117,10 +117,11 @@ void Instance::warmup() {
 }
 
 Session& Instance::startSession(const Session::InitParams params) {
-    if (!m_session) {
-        m_session.reset(new Session(*this, m_lctx.get(), params));
+    if (m_session) {
+        throw_ex{} << "Session is already started. Stop it to start a new one.";
     }
 
+    m_session.reset(new Session(*this, m_lctx.get(), params));
     return *m_session;
 }
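
startSession now enforces a single live session per Instance. Previously a second call silently returned the already-running session and ignored the new InitParams; now it throws. The resulting contract, sketched with identifiers from this diff:

    auto& s = instance.startSession({});     // creates the Instance-owned session
    // auto& t = instance.startSession({});  // would throw: "Session is already started..."

How a session is stopped is not shown in this commit; the exception text implies a corresponding stop operation.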

code/ac/llama/Session.cpp

Lines changed: 15 additions & 4 deletions
@@ -79,13 +79,16 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
     }
 
     doDecode(initialPrompt, Source::InitialPrompt);
+    m_state.m_phase = State::Phase::Generating;
 }
 
 void Session::pushPrompt(std::span<const Token> prompt) {
     if (m_state.m_phase != State::Phase::Generating) {
         throw_ex{} << "Session hasn't started yet";
     }
 
+    flushPendingState();
+
     if (!prompt.empty()) {
         auto& sampler = m_instance.sampler();
         auto& model = m_instance.model();
@@ -108,10 +111,7 @@ Token Session::getToken() {
         throw_ex{} << "Session hasn't started yet";
     }
 
-    if (m_state.m_currToken != Token_Invalid) {
-        // first yield, then decode, thus we don't decode if the session is aborted
-        doDecode({&m_state.m_currToken, 1}, Source::Generated);
-    }
+    flushPendingState();
 
     auto& sampler = m_instance.sampler();
     auto& vocab = m_instance.model().vocab();
@@ -131,6 +131,8 @@ std::vector<uint8_t> Session::getState() {
         throw_ex{} << "Session hasn't started yet";
     }
 
+    flushPendingState();
+
     const auto size = llama_state_get_size(m_ctx);
     std::vector<uint8_t> state(size);
     if (llama_state_get_data(m_ctx, state.data(), size) != size) {
@@ -147,6 +149,8 @@ bool Session::setState(std::span<uint8_t> state) {
     if (llama_state_set_data(m_ctx, state.data(), state.size()) != state.size()) {
         throw_ex{} << "Failed to set state";
     }
+
+    m_state.m_phase = State::Phase::Generating;
     return true;
 }
 
@@ -235,4 +239,11 @@ void Session::doDecode(std::span<const Token> tokens, Source src) {
 
 }
 
+void Session::flushPendingState() {
+    if (m_state.m_currToken != Token_Invalid) {
+        // first yield, then decode, thus we don't decode if the session is aborted
+        doDecode({&m_state.m_currToken, 1}, Source::Generated);
+        m_state.m_currToken = Token_Invalid;
+    }
+}
 } // namespace ac::llama
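
The common thread of these hunks is flushPendingState: getToken samples a token and yields it, but defers decoding it into the llama context until the next interaction (getToken, pushPrompt, or getState). An aborted session therefore never pays for decoding a token nobody consumed, and getState snapshots only after the pending token has been applied. setInitialPrompt and setState both leave the session in Phase::Generating, so a session restored from a snapshot is immediately usable. A hedged sketch of the resulting call pattern (the loop bound and the tokens variable are illustrative):

    auto& s = instance.startSession({});
    s.setInitialPrompt(tokens);   // decode the prompt, then phase -> Generating
    for (int i = 0; i < 100; ++i) {
        auto t = s.getToken();    // flush the previous pending token, sample the next
        // ... consume t; stopping here leaves t undecoded, by design
    }
    auto snapshot = s.getState(); // flushes first, so the snapshot reflects every yielded token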

code/ac/llama/Session.hpp

Lines changed: 7 additions & 1 deletion
@@ -26,13 +26,18 @@ class Session {
         bool infiniteContext = true;
     };
     Session(Instance& instance, llama_context* ctx, InitParams params);
+    Session(const Session&) = delete;
+    Session& operator=(const Session&) = delete;
+    ~Session() = default;
 
+    // initial functions to prepare the session
     void setInitialPrompt(std::span<const Token> prompt);
+    bool setState(std::span<uint8_t> state);
 
+    // main functions to interact with the model
     void pushPrompt(std::span<const Token> prompt);
     Token getToken();
     std::vector<uint8_t> getState();
-    bool setState(std::span<uint8_t> state);
 private:
     enum class Source {
         InitialPrompt,
@@ -41,6 +46,7 @@ class Session {
     };
 
     void doDecode(std::span<const Token> tokens, Source src);
+    void flushPendingState();
 
     struct State {
         enum class Phase {
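
Deleting the copy operations matches the new ownership model: there is exactly one Session per Instance, startSession returns a reference to it, and holders keep that reference. A copy attempt now fails at compile time:

    auto& s = instance.startSession({});
    // ac::llama::Session copy = s;  // error: Session(const Session&) = delete
    ac::llama::Session& alias = s;   // fine: another reference to the same session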

example/e-basic.cpp

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ int main() try {
     std::cout << "Prompt: " << prompt << "\n";
 
     // start session
-    auto session = instance.startSession({});
+    auto& session = instance.startSession({});
     session.setInitialPrompt(model.vocab().tokenize(prompt, true, true));
 
     // generate and print 100 tokens
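
The auto& matters here: plain auto deduces a value type and would try to copy the returned reference, which no longer compiles now that Session's copy constructor is deleted.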

example/e-gui.cpp

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ class UModel {
     ac::llama::Session::InitParams m_params;
     std::vector<ac::llama::Token> m_promptTokens;
     std::string m_text;
-    ac::llama::Session m_session;
+    ac::llama::Session& m_session;
     ac::llama::AntipromptManager m_antiprompt;
     uint32_t m_numTokens = 0;
 };
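
Because m_session is now a reference member, UModel can no longer default-construct or reassign it; it must be bound in the constructor's initializer list. A hypothetical sketch (UModel's constructor is not part of this commit, and the InitParams passed are an assumption):

    // Hypothetical sketch, not part of the diff.
    UModel(ac::llama::Instance& instance)
        : m_session(instance.startSession({})) // bound once; cannot be reseated later
    {}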

test/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@ endmacro()
 
 llama_test(Antiprompt)
 llama_test(ChatFormat)
-llama_test(Session)
 
 add_doctest_lib_test(integration ac-llama
     SOURCES

test/t-Session.cpp

Lines changed: 0 additions & 157 deletions
This file was deleted.
