1616// limitations under the License.
1717
1818#include " db.h"
19+ #include " llamafile/json.h"
20+ #include " llamafile/llamafile.h"
21+ #include " third_party/sqlite/sqlite3.h"
22+ #include < pthread.h>
1923#include < stdio.h>
24+ #include < stdlib.h>
2025#include < string>
2126
2227__static_yoink (" llamafile/schema.sql" );
2328
2429#define SCHEMA_VERSION 1
2530
26- namespace llamafile {
31+ namespace lf {
2732namespace db {
2833
29- static bool table_exists (sqlite3* db, const char * table_name) {
30- const char * query = " SELECT name FROM sqlite_master WHERE type='table' AND name=?;" ;
31- sqlite3_stmt* stmt;
34+ static bool table_exists (sqlite3 * db, const char * table_name) {
35+ const char * query = " SELECT name FROM sqlite_master WHERE type='table' AND name=?;" ;
36+ sqlite3_stmt * stmt;
3237 if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
3338 return false ;
3439 }
@@ -41,16 +46,16 @@ static bool table_exists(sqlite3* db, const char* table_name) {
4146 return exists;
4247}
4348
44- static bool init_schema (sqlite3* db) {
45- FILE* f = fopen (" /zip/llamafile/schema.sql" , " r" );
49+ static bool init_schema (sqlite3 * db) {
50+ FILE * f = fopen (" /zip/llamafile/schema.sql" , " r" );
4651 if (!f)
4752 return false ;
4853 std::string schema;
4954 int c;
5055 while ((c = fgetc (f)) != EOF)
5156 schema += c;
5257 fclose (f);
53- char * errmsg = nullptr ;
58+ char * errmsg = nullptr ;
5459 int rc = sqlite3_exec (db, schema.c_str (), nullptr , nullptr , &errmsg);
5560 if (rc != SQLITE_OK) {
5661 if (errmsg) {
@@ -62,37 +67,310 @@ static bool init_schema(sqlite3* db) {
6267 return true ;
6368}
6469
65- sqlite3* open (const char * path) {
66- sqlite3* db;
67- int rc = sqlite3_open (path, &db);
70+ static sqlite3 *open_impl () {
71+ std::string path;
72+ if (FLAG_db) {
73+ path = FLAG_db;
74+ } else {
75+ const char *home = getenv (" HOME" );
76+ if (home) {
77+ path = std::string (home) + " /.llamafile/llamafile.sqlite3" ;
78+ } else {
79+ path = " llamafile.sqlite3" ;
80+ }
81+ }
82+ sqlite3 *db;
83+ int rc = sqlite3_open (path.c_str (), &db);
6884 if (rc) {
69- fprintf (stderr, " %s: can't open database: %s\n " , path, sqlite3_errmsg (db));
85+ fprintf (stderr, " %s: can't open database: %s\n " , path. c_str () , sqlite3_errmsg (db));
7086 return nullptr ;
7187 }
72- char * errmsg = nullptr ;
88+ char * errmsg = nullptr ;
7389 if (sqlite3_exec (db, " PRAGMA journal_mode=WAL;" , nullptr , nullptr , &errmsg) != SQLITE_OK) {
74- fprintf (stderr, " Failed to set journal mode to WAL : %s\n " , errmsg);
90+ fprintf (stderr, " %s: failed to set journal mode to wal : %s\n " , path. c_str () , errmsg);
7591 sqlite3_free (errmsg);
7692 sqlite3_close (db);
7793 return nullptr ;
7894 }
7995 if (sqlite3_exec (db, " PRAGMA synchronous=NORMAL;" , nullptr , nullptr , &errmsg) != SQLITE_OK) {
80- fprintf (stderr, " Failed to set synchronous to NORMAL : %s\n " , errmsg);
96+ fprintf (stderr, " %s: failed to set synchronous to normal : %s\n " , path. c_str () , errmsg);
8197 sqlite3_free (errmsg);
8298 sqlite3_close (db);
8399 return nullptr ;
84100 }
85101 if (!table_exists (db, " metadata" ) && !init_schema (db)) {
86- fprintf (stderr, " %s: failed to initialize database schema\n " , path);
102+ fprintf (stderr, " %s: failed to initialize database schema\n " , path. c_str () );
87103 sqlite3_close (db);
88104 return nullptr ;
89105 }
90106 return db;
91107}
92108
93- void close (sqlite3* db) {
109+ sqlite3 *open () {
110+ int cs;
111+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
112+ sqlite3 *res = open_impl ();
113+ pthread_setcancelstate (cs, 0 );
114+ return res;
115+ }
116+
117+ void close (sqlite3 *db) {
118+ int cs;
119+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
94120 sqlite3_close (db);
121+ pthread_setcancelstate (cs, 0 );
122+ }
123+
124+ static int64_t add_chat_impl (sqlite3 *db, const std::string &model, const std::string &title) {
125+ const char *query = " INSERT INTO chats (model, title) VALUES (?, ?);" ;
126+ sqlite3_stmt *stmt;
127+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
128+ return -1 ;
129+ }
130+ if (sqlite3_bind_text (stmt, 1 , model.data (), model.size (), SQLITE_STATIC) != SQLITE_OK ||
131+ sqlite3_bind_text (stmt, 2 , title.data (), title.size (), SQLITE_STATIC) != SQLITE_OK) {
132+ sqlite3_finalize (stmt);
133+ return -1 ;
134+ }
135+ if (sqlite3_step (stmt) != SQLITE_DONE) {
136+ sqlite3_finalize (stmt);
137+ return -1 ;
138+ }
139+ sqlite3_finalize (stmt);
140+ return sqlite3_last_insert_rowid (db);
141+ }
142+
143+ int64_t add_chat (sqlite3 *db, const std::string &model, const std::string &title) {
144+ int cs;
145+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
146+ int64_t res = add_chat_impl (db, model, title);
147+ pthread_setcancelstate (cs, 0 );
148+ return res;
149+ }
150+
151+ static int64_t add_message_impl (sqlite3 *db, int64_t chat_id, const std::string &role,
152+ const std::string &content, double temperature, double top_p,
153+ double presence_penalty, double frequency_penalty) {
154+ const char *query = " INSERT INTO messages (chat_id, role, content, temperature, "
155+ " top_p, presence_penalty, frequency_penalty) "
156+ " VALUES (?, ?, ?, ?, ?, ?, ?);" ;
157+ sqlite3_stmt *stmt;
158+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
159+ return -1 ;
160+ }
161+ if (sqlite3_bind_int64 (stmt, 1 , chat_id) != SQLITE_OK ||
162+ sqlite3_bind_text (stmt, 2 , role.data (), role.size (), SQLITE_STATIC) != SQLITE_OK ||
163+ sqlite3_bind_text (stmt, 3 , content.data (), content.size (), SQLITE_STATIC) != SQLITE_OK ||
164+ sqlite3_bind_double (stmt, 4 , temperature) != SQLITE_OK ||
165+ sqlite3_bind_double (stmt, 5 , top_p) != SQLITE_OK ||
166+ sqlite3_bind_double (stmt, 6 , presence_penalty) != SQLITE_OK ||
167+ sqlite3_bind_double (stmt, 7 , frequency_penalty) != SQLITE_OK) {
168+ sqlite3_finalize (stmt);
169+ return -1 ;
170+ }
171+ if (sqlite3_step (stmt) != SQLITE_DONE) {
172+ sqlite3_finalize (stmt);
173+ return -1 ;
174+ }
175+ sqlite3_finalize (stmt);
176+ return sqlite3_last_insert_rowid (db);
177+ }
178+
179+ int64_t add_message (sqlite3 *db, int64_t chat_id, const std::string &role,
180+ const std::string &content, double temperature, double top_p,
181+ double presence_penalty, double frequency_penalty) {
182+ int cs;
183+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
184+ int64_t res = add_message_impl (db, chat_id, role, content, temperature, top_p, presence_penalty,
185+ frequency_penalty);
186+ pthread_setcancelstate (cs, 0 );
187+ return res;
188+ }
189+
190+ static bool update_title_impl (sqlite3 *db, int64_t chat_id, const std::string &title) {
191+ const char *query = " UPDATE chats SET title = ? WHERE id = ?;" ;
192+ sqlite3_stmt *stmt;
193+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
194+ return false ;
195+ }
196+ if (sqlite3_bind_text (stmt, 1 , title.data (), title.size (), SQLITE_STATIC) != SQLITE_OK ||
197+ sqlite3_bind_int64 (stmt, 2 , chat_id) != SQLITE_OK) {
198+ sqlite3_finalize (stmt);
199+ return false ;
200+ }
201+ bool success = sqlite3_step (stmt) == SQLITE_DONE;
202+ sqlite3_finalize (stmt);
203+ return success;
204+ }
205+
206+ bool update_title (sqlite3 *db, int64_t chat_id, const std::string &title) {
207+ int cs;
208+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
209+ bool res = update_title_impl (db, chat_id, title);
210+ pthread_setcancelstate (cs, 0 );
211+ return res;
212+ }
213+
214+ static bool delete_message_impl (sqlite3 *db, int64_t message_id) {
215+ const char *query = " DELETE FROM messages WHERE id = ?;" ;
216+ sqlite3_stmt *stmt;
217+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
218+ return false ;
219+ }
220+ if (sqlite3_bind_int64 (stmt, 1 , message_id) != SQLITE_OK) {
221+ sqlite3_finalize (stmt);
222+ return false ;
223+ }
224+ bool success = sqlite3_step (stmt) == SQLITE_DONE;
225+ sqlite3_finalize (stmt);
226+ return success;
227+ }
228+
229+ bool delete_message (sqlite3 *db, int64_t message_id) {
230+ int cs;
231+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
232+ bool res = delete_message_impl (db, message_id);
233+ pthread_setcancelstate (cs, 0 );
234+ return res;
235+ }
236+
237+ static jt::Json get_chats_impl (sqlite3 *db) {
238+ const char *query = " SELECT id, created_at, model, title FROM chats ORDER BY created_at DESC;" ;
239+ sqlite3_stmt *stmt;
240+ jt::Json result;
241+ result.setArray ();
242+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
243+ return result;
244+ }
245+ while (sqlite3_step (stmt) == SQLITE_ROW) {
246+ jt::Json chat;
247+ chat.setObject ();
248+ chat[" id" ] = sqlite3_column_int64 (stmt, 0 );
249+ chat[" created_at" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 1 ));
250+ chat[" model" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 2 ));
251+ chat[" title" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 3 ));
252+ result.getArray ().push_back (std::move (chat));
253+ }
254+ sqlite3_finalize (stmt);
255+ return result;
256+ }
257+
258+ jt::Json get_chats (sqlite3 *db) {
259+ int cs;
260+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
261+ jt::Json res = get_chats_impl (db);
262+ pthread_setcancelstate (cs, 0 );
263+ return res;
264+ }
265+
266+ static jt::Json get_messages_impl (sqlite3 *db, int64_t chat_id) {
267+ const char *query = " SELECT id, created_at, role, content, temperature, top_p, "
268+ " presence_penalty, frequency_penalty "
269+ " FROM messages "
270+ " WHERE chat_id = ? "
271+ " ORDER BY created_at DESC;" ;
272+ sqlite3_stmt *stmt;
273+ jt::Json result;
274+ result.setArray ();
275+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
276+ return result;
277+ }
278+ if (sqlite3_bind_int64 (stmt, 1 , chat_id) != SQLITE_OK) {
279+ sqlite3_finalize (stmt);
280+ return result;
281+ }
282+ while (sqlite3_step (stmt) == SQLITE_ROW) {
283+ jt::Json msg;
284+ msg.setObject ();
285+ msg[" id" ] = sqlite3_column_int64 (stmt, 0 );
286+ msg[" created_at" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 1 ));
287+ msg[" role" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 2 ));
288+ msg[" content" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 3 ));
289+ msg[" temperature" ] = sqlite3_column_double (stmt, 4 );
290+ msg[" top_p" ] = sqlite3_column_double (stmt, 5 );
291+ msg[" presence_penalty" ] = sqlite3_column_double (stmt, 6 );
292+ msg[" frequency_penalty" ] = sqlite3_column_double (stmt, 7 );
293+ result.getArray ().push_back (std::move (msg));
294+ }
295+ sqlite3_finalize (stmt);
296+ return result;
297+ }
298+
299+ jt::Json get_messages (sqlite3 *db, int64_t chat_id) {
300+ int cs;
301+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
302+ jt::Json res = get_messages_impl (db, chat_id);
303+ pthread_setcancelstate (cs, 0 );
304+ return res;
305+ }
306+
307+ static jt::Json get_chat_impl (sqlite3 *db, int64_t chat_id) {
308+ const char *query = " SELECT id, created_at, model, title FROM chats WHERE id = ?;" ;
309+ sqlite3_stmt *stmt;
310+ jt::Json result;
311+ result.setObject ();
312+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
313+ return result;
314+ }
315+ if (sqlite3_bind_int64 (stmt, 1 , chat_id) != SQLITE_OK) {
316+ sqlite3_finalize (stmt);
317+ return result;
318+ }
319+ if (sqlite3_step (stmt) == SQLITE_ROW) {
320+ result[" id" ] = sqlite3_column_int64 (stmt, 0 );
321+ result[" created_at" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 1 ));
322+ result[" model" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 2 ));
323+ result[" title" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 3 ));
324+ }
325+ sqlite3_finalize (stmt);
326+ return result;
327+ }
328+
329+ jt::Json get_chat (sqlite3 *db, int64_t chat_id) {
330+ int cs;
331+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
332+ jt::Json res = get_chat_impl (db, chat_id);
333+ pthread_setcancelstate (cs, 0 );
334+ return res;
335+ }
336+
337+ static jt::Json get_message_impl (sqlite3 *db, int64_t message_id) {
338+ const char *query = " SELECT id, created_at, chat_id, role, content, temperature, top_p, "
339+ " presence_penalty, frequency_penalty "
340+ " FROM messages WHERE id = ?"
341+ " ORDER BY created_at ASC;" ;
342+ sqlite3_stmt *stmt;
343+ jt::Json result;
344+ result.setObject ();
345+ if (sqlite3_prepare_v2 (db, query, -1 , &stmt, nullptr ) != SQLITE_OK) {
346+ return result;
347+ }
348+ if (sqlite3_bind_int64 (stmt, 1 , message_id) != SQLITE_OK) {
349+ sqlite3_finalize (stmt);
350+ return result;
351+ }
352+ if (sqlite3_step (stmt) == SQLITE_ROW) {
353+ result[" id" ] = sqlite3_column_int64 (stmt, 0 );
354+ result[" created_at" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 1 ));
355+ result[" chat_id" ] = sqlite3_column_int64 (stmt, 2 );
356+ result[" role" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 3 ));
357+ result[" content" ] = reinterpret_cast <const char *>(sqlite3_column_text (stmt, 4 ));
358+ result[" temperature" ] = sqlite3_column_double (stmt, 5 );
359+ result[" top_p" ] = sqlite3_column_double (stmt, 6 );
360+ result[" presence_penalty" ] = sqlite3_column_double (stmt, 7 );
361+ result[" frequency_penalty" ] = sqlite3_column_double (stmt, 8 );
362+ }
363+ sqlite3_finalize (stmt);
364+ return result;
365+ }
366+
367+ jt::Json get_message (sqlite3 *db, int64_t message_id) {
368+ int cs;
369+ pthread_setcancelstate (PTHREAD_CANCEL_DISABLE, &cs);
370+ jt::Json res = get_message_impl (db, message_id);
371+ pthread_setcancelstate (cs, 0 );
372+ return res;
95373}
96374
97375} // namespace db
98- } // namespace llamafile
376+ } // namespace lf
0 commit comments