Skip to content

Commit 36696f3

Browse files
committed
Add preliminary chat history functionality
1 parent abe0d1d commit 36696f3

File tree

20 files changed

+633
-64
lines changed

20 files changed

+633
-64
lines changed

llamafile/BUILD.mk

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ o/$(MODE)/llamafile: \
6363
o/$(MODE)/llamafile/parse_cidr_test.runs \
6464
o/$(MODE)/llamafile/pool_cancel_test.runs \
6565
o/$(MODE)/llamafile/pool_test.runs \
66+
o/$(MODE)/llamafile/json_test.runs \
6667
o/$(MODE)/llamafile/thread_test.runs \
6768
o/$(MODE)/llamafile/vmathf_test.runs \
6869

@@ -156,6 +157,12 @@ o/$(MODE)/llamafile/tinyblas_cpu_sgemm_arm82.o: \
156157
################################################################################
157158
# testing
158159

160+
o/$(MODE)/llamafile/json_test: \
161+
o/$(MODE)/llamafile/json_test.o \
162+
o/$(MODE)/llamafile/json.o \
163+
o/$(MODE)/llamafile/hextoint.o \
164+
o/$(MODE)/double-conversion/double-conversion.a \
165+
159166
o/$(MODE)/llamafile/vmathf_test: \
160167
o/$(MODE)/llamafile/vmathf_test.o \
161168
o/$(MODE)/llama.cpp/llama.cpp.a \

llamafile/db.cpp

Lines changed: 295 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,24 @@
1616
// limitations under the License.
1717

1818
#include "db.h"
19+
#include "llamafile/json.h"
20+
#include "llamafile/llamafile.h"
21+
#include "third_party/sqlite/sqlite3.h"
22+
#include <pthread.h>
1923
#include <stdio.h>
24+
#include <stdlib.h>
2025
#include <string>
2126

2227
__static_yoink("llamafile/schema.sql");
2328

2429
#define SCHEMA_VERSION 1
2530

26-
namespace llamafile {
31+
namespace lf {
2732
namespace db {
2833

29-
static bool table_exists(sqlite3* db, const char* table_name) {
30-
const char* query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;";
31-
sqlite3_stmt* stmt;
34+
static bool table_exists(sqlite3 *db, const char *table_name) {
35+
const char *query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;";
36+
sqlite3_stmt *stmt;
3237
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
3338
return false;
3439
}
@@ -41,16 +46,16 @@ static bool table_exists(sqlite3* db, const char* table_name) {
4146
return exists;
4247
}
4348

44-
static bool init_schema(sqlite3* db) {
45-
FILE* f = fopen("/zip/llamafile/schema.sql", "r");
49+
static bool init_schema(sqlite3 *db) {
50+
FILE *f = fopen("/zip/llamafile/schema.sql", "r");
4651
if (!f)
4752
return false;
4853
std::string schema;
4954
int c;
5055
while ((c = fgetc(f)) != EOF)
5156
schema += c;
5257
fclose(f);
53-
char* errmsg = nullptr;
58+
char *errmsg = nullptr;
5459
int rc = sqlite3_exec(db, schema.c_str(), nullptr, nullptr, &errmsg);
5560
if (rc != SQLITE_OK) {
5661
if (errmsg) {
@@ -62,37 +67,310 @@ static bool init_schema(sqlite3* db) {
6267
return true;
6368
}
6469

65-
sqlite3* open(const char* path) {
66-
sqlite3* db;
67-
int rc = sqlite3_open(path, &db);
70+
static sqlite3 *open_impl() {
71+
std::string path;
72+
if (FLAG_db) {
73+
path = FLAG_db;
74+
} else {
75+
const char *home = getenv("HOME");
76+
if (home) {
77+
path = std::string(home) + "/.llamafile/llamafile.sqlite3";
78+
} else {
79+
path = "llamafile.sqlite3";
80+
}
81+
}
82+
sqlite3 *db;
83+
int rc = sqlite3_open(path.c_str(), &db);
6884
if (rc) {
69-
fprintf(stderr, "%s: can't open database: %s\n", path, sqlite3_errmsg(db));
85+
fprintf(stderr, "%s: can't open database: %s\n", path.c_str(), sqlite3_errmsg(db));
7086
return nullptr;
7187
}
72-
char* errmsg = nullptr;
88+
char *errmsg = nullptr;
7389
if (sqlite3_exec(db, "PRAGMA journal_mode=WAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
74-
fprintf(stderr, "Failed to set journal mode to WAL: %s\n", errmsg);
90+
fprintf(stderr, "%s: failed to set journal mode to wal: %s\n", path.c_str(), errmsg);
7591
sqlite3_free(errmsg);
7692
sqlite3_close(db);
7793
return nullptr;
7894
}
7995
if (sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
80-
fprintf(stderr, "Failed to set synchronous to NORMAL: %s\n", errmsg);
96+
fprintf(stderr, "%s: failed to set synchronous to normal: %s\n", path.c_str(), errmsg);
8197
sqlite3_free(errmsg);
8298
sqlite3_close(db);
8399
return nullptr;
84100
}
85101
if (!table_exists(db, "metadata") && !init_schema(db)) {
86-
fprintf(stderr, "%s: failed to initialize database schema\n", path);
102+
fprintf(stderr, "%s: failed to initialize database schema\n", path.c_str());
87103
sqlite3_close(db);
88104
return nullptr;
89105
}
90106
return db;
91107
}
92108

93-
void close(sqlite3* db) {
109+
sqlite3 *open() {
110+
int cs;
111+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
112+
sqlite3 *res = open_impl();
113+
pthread_setcancelstate(cs, 0);
114+
return res;
115+
}
116+
117+
void close(sqlite3 *db) {
118+
int cs;
119+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
94120
sqlite3_close(db);
121+
pthread_setcancelstate(cs, 0);
122+
}
123+
124+
static int64_t add_chat_impl(sqlite3 *db, const std::string &model, const std::string &title) {
125+
const char *query = "INSERT INTO chats (model, title) VALUES (?, ?);";
126+
sqlite3_stmt *stmt;
127+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
128+
return -1;
129+
}
130+
if (sqlite3_bind_text(stmt, 1, model.data(), model.size(), SQLITE_STATIC) != SQLITE_OK ||
131+
sqlite3_bind_text(stmt, 2, title.data(), title.size(), SQLITE_STATIC) != SQLITE_OK) {
132+
sqlite3_finalize(stmt);
133+
return -1;
134+
}
135+
if (sqlite3_step(stmt) != SQLITE_DONE) {
136+
sqlite3_finalize(stmt);
137+
return -1;
138+
}
139+
sqlite3_finalize(stmt);
140+
return sqlite3_last_insert_rowid(db);
141+
}
142+
143+
int64_t add_chat(sqlite3 *db, const std::string &model, const std::string &title) {
144+
int cs;
145+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
146+
int64_t res = add_chat_impl(db, model, title);
147+
pthread_setcancelstate(cs, 0);
148+
return res;
149+
}
150+
151+
static int64_t add_message_impl(sqlite3 *db, int64_t chat_id, const std::string &role,
152+
const std::string &content, double temperature, double top_p,
153+
double presence_penalty, double frequency_penalty) {
154+
const char *query = "INSERT INTO messages (chat_id, role, content, temperature, "
155+
"top_p, presence_penalty, frequency_penalty) "
156+
"VALUES (?, ?, ?, ?, ?, ?, ?);";
157+
sqlite3_stmt *stmt;
158+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
159+
return -1;
160+
}
161+
if (sqlite3_bind_int64(stmt, 1, chat_id) != SQLITE_OK ||
162+
sqlite3_bind_text(stmt, 2, role.data(), role.size(), SQLITE_STATIC) != SQLITE_OK ||
163+
sqlite3_bind_text(stmt, 3, content.data(), content.size(), SQLITE_STATIC) != SQLITE_OK ||
164+
sqlite3_bind_double(stmt, 4, temperature) != SQLITE_OK ||
165+
sqlite3_bind_double(stmt, 5, top_p) != SQLITE_OK ||
166+
sqlite3_bind_double(stmt, 6, presence_penalty) != SQLITE_OK ||
167+
sqlite3_bind_double(stmt, 7, frequency_penalty) != SQLITE_OK) {
168+
sqlite3_finalize(stmt);
169+
return -1;
170+
}
171+
if (sqlite3_step(stmt) != SQLITE_DONE) {
172+
sqlite3_finalize(stmt);
173+
return -1;
174+
}
175+
sqlite3_finalize(stmt);
176+
return sqlite3_last_insert_rowid(db);
177+
}
178+
179+
int64_t add_message(sqlite3 *db, int64_t chat_id, const std::string &role,
180+
const std::string &content, double temperature, double top_p,
181+
double presence_penalty, double frequency_penalty) {
182+
int cs;
183+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
184+
int64_t res = add_message_impl(db, chat_id, role, content, temperature, top_p, presence_penalty,
185+
frequency_penalty);
186+
pthread_setcancelstate(cs, 0);
187+
return res;
188+
}
189+
190+
static bool update_title_impl(sqlite3 *db, int64_t chat_id, const std::string &title) {
191+
const char *query = "UPDATE chats SET title = ? WHERE id = ?;";
192+
sqlite3_stmt *stmt;
193+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
194+
return false;
195+
}
196+
if (sqlite3_bind_text(stmt, 1, title.data(), title.size(), SQLITE_STATIC) != SQLITE_OK ||
197+
sqlite3_bind_int64(stmt, 2, chat_id) != SQLITE_OK) {
198+
sqlite3_finalize(stmt);
199+
return false;
200+
}
201+
bool success = sqlite3_step(stmt) == SQLITE_DONE;
202+
sqlite3_finalize(stmt);
203+
return success;
204+
}
205+
206+
bool update_title(sqlite3 *db, int64_t chat_id, const std::string &title) {
207+
int cs;
208+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
209+
bool res = update_title_impl(db, chat_id, title);
210+
pthread_setcancelstate(cs, 0);
211+
return res;
212+
}
213+
214+
static bool delete_message_impl(sqlite3 *db, int64_t message_id) {
215+
const char *query = "DELETE FROM messages WHERE id = ?;";
216+
sqlite3_stmt *stmt;
217+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
218+
return false;
219+
}
220+
if (sqlite3_bind_int64(stmt, 1, message_id) != SQLITE_OK) {
221+
sqlite3_finalize(stmt);
222+
return false;
223+
}
224+
bool success = sqlite3_step(stmt) == SQLITE_DONE;
225+
sqlite3_finalize(stmt);
226+
return success;
227+
}
228+
229+
bool delete_message(sqlite3 *db, int64_t message_id) {
230+
int cs;
231+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
232+
bool res = delete_message_impl(db, message_id);
233+
pthread_setcancelstate(cs, 0);
234+
return res;
235+
}
236+
237+
static jt::Json get_chats_impl(sqlite3 *db) {
238+
const char *query = "SELECT id, created_at, model, title FROM chats ORDER BY created_at DESC;";
239+
sqlite3_stmt *stmt;
240+
jt::Json result;
241+
result.setArray();
242+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
243+
return result;
244+
}
245+
while (sqlite3_step(stmt) == SQLITE_ROW) {
246+
jt::Json chat;
247+
chat.setObject();
248+
chat["id"] = sqlite3_column_int64(stmt, 0);
249+
chat["created_at"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
250+
chat["model"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2));
251+
chat["title"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3));
252+
result.getArray().push_back(std::move(chat));
253+
}
254+
sqlite3_finalize(stmt);
255+
return result;
256+
}
257+
258+
jt::Json get_chats(sqlite3 *db) {
259+
int cs;
260+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
261+
jt::Json res = get_chats_impl(db);
262+
pthread_setcancelstate(cs, 0);
263+
return res;
264+
}
265+
266+
static jt::Json get_messages_impl(sqlite3 *db, int64_t chat_id) {
267+
const char *query = "SELECT id, created_at, role, content, temperature, top_p, "
268+
"presence_penalty, frequency_penalty "
269+
"FROM messages "
270+
"WHERE chat_id = ? "
271+
"ORDER BY created_at DESC;";
272+
sqlite3_stmt *stmt;
273+
jt::Json result;
274+
result.setArray();
275+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
276+
return result;
277+
}
278+
if (sqlite3_bind_int64(stmt, 1, chat_id) != SQLITE_OK) {
279+
sqlite3_finalize(stmt);
280+
return result;
281+
}
282+
while (sqlite3_step(stmt) == SQLITE_ROW) {
283+
jt::Json msg;
284+
msg.setObject();
285+
msg["id"] = sqlite3_column_int64(stmt, 0);
286+
msg["created_at"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
287+
msg["role"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2));
288+
msg["content"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3));
289+
msg["temperature"] = sqlite3_column_double(stmt, 4);
290+
msg["top_p"] = sqlite3_column_double(stmt, 5);
291+
msg["presence_penalty"] = sqlite3_column_double(stmt, 6);
292+
msg["frequency_penalty"] = sqlite3_column_double(stmt, 7);
293+
result.getArray().push_back(std::move(msg));
294+
}
295+
sqlite3_finalize(stmt);
296+
return result;
297+
}
298+
299+
jt::Json get_messages(sqlite3 *db, int64_t chat_id) {
300+
int cs;
301+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
302+
jt::Json res = get_messages_impl(db, chat_id);
303+
pthread_setcancelstate(cs, 0);
304+
return res;
305+
}
306+
307+
static jt::Json get_chat_impl(sqlite3 *db, int64_t chat_id) {
308+
const char *query = "SELECT id, created_at, model, title FROM chats WHERE id = ?;";
309+
sqlite3_stmt *stmt;
310+
jt::Json result;
311+
result.setObject();
312+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
313+
return result;
314+
}
315+
if (sqlite3_bind_int64(stmt, 1, chat_id) != SQLITE_OK) {
316+
sqlite3_finalize(stmt);
317+
return result;
318+
}
319+
if (sqlite3_step(stmt) == SQLITE_ROW) {
320+
result["id"] = sqlite3_column_int64(stmt, 0);
321+
result["created_at"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
322+
result["model"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2));
323+
result["title"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3));
324+
}
325+
sqlite3_finalize(stmt);
326+
return result;
327+
}
328+
329+
jt::Json get_chat(sqlite3 *db, int64_t chat_id) {
330+
int cs;
331+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
332+
jt::Json res = get_chat_impl(db, chat_id);
333+
pthread_setcancelstate(cs, 0);
334+
return res;
335+
}
336+
337+
static jt::Json get_message_impl(sqlite3 *db, int64_t message_id) {
338+
const char *query = "SELECT id, created_at, chat_id, role, content, temperature, top_p, "
339+
"presence_penalty, frequency_penalty "
340+
"FROM messages WHERE id = ?"
341+
"ORDER BY created_at ASC;";
342+
sqlite3_stmt *stmt;
343+
jt::Json result;
344+
result.setObject();
345+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
346+
return result;
347+
}
348+
if (sqlite3_bind_int64(stmt, 1, message_id) != SQLITE_OK) {
349+
sqlite3_finalize(stmt);
350+
return result;
351+
}
352+
if (sqlite3_step(stmt) == SQLITE_ROW) {
353+
result["id"] = sqlite3_column_int64(stmt, 0);
354+
result["created_at"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
355+
result["chat_id"] = sqlite3_column_int64(stmt, 2);
356+
result["role"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3));
357+
result["content"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 4));
358+
result["temperature"] = sqlite3_column_double(stmt, 5);
359+
result["top_p"] = sqlite3_column_double(stmt, 6);
360+
result["presence_penalty"] = sqlite3_column_double(stmt, 7);
361+
result["frequency_penalty"] = sqlite3_column_double(stmt, 8);
362+
}
363+
sqlite3_finalize(stmt);
364+
return result;
365+
}
366+
367+
jt::Json get_message(sqlite3 *db, int64_t message_id) {
368+
int cs;
369+
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
370+
jt::Json res = get_message_impl(db, message_id);
371+
pthread_setcancelstate(cs, 0);
372+
return res;
95373
}
96374

97375
} // namespace db
98-
} // namespace llamafile
376+
} // namespace lf

0 commit comments

Comments
 (0)