22// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
33#include " llama.cpp/llama.h"
44#include " llamafile/version.h"
5+ #include " llama.cpp/embedr/embedr.h"
56#include " llama.cpp/embedr/sqlite3.h"
67#include " llama.cpp/embedr/sqlite-vec.h"
78#include " llama.cpp/embedr/sqlite-lembed.h"
@@ -24,13 +25,18 @@ int64_t time_ms(void) {
2425
2526char * EMBEDR_MODEL = NULL ;
2627
28+ void embedr_version (sqlite3_context * context, int argc, sqlite3_value **value) {
29+ sqlite3_result_text (context, EMBEDR_VERSION, -1 , SQLITE_STATIC);
30+ }
31+
2732int embedr_sqlite3_init (sqlite3 * db) {
2833 int rc;
2934
3035 rc = sqlite3_vec_init (db, NULL , NULL ); assert (rc == SQLITE_OK);
3136 rc = sqlite3_lembed_init (db, NULL , NULL ); assert (rc == SQLITE_OK);
3237 rc = sqlite3_csv_init (db, NULL , NULL ); assert (rc == SQLITE_OK);
3338 rc = sqlite3_lines_init (db, NULL , NULL ); assert (rc == SQLITE_OK);
39+ rc = sqlite3_create_function_v2 (db, " embedr_version" ,0 , SQLITE_DETERMINISTIC | SQLITE_UTF8, NULL , embedr_version, NULL , NULL , NULL ); assert (rc == SQLITE_OK);
3440
3541 if (!EMBEDR_MODEL) {
3642 return SQLITE_OK;
@@ -87,6 +93,142 @@ void print_progress_bar(long long nEmbed, long long nTotal, long long elapsed_ms
8793 fflush (stdout);
8894}
8995
96+ int default_model_dimensions (sqlite3 * db, int64_t * dimensions) {
97+ int rc;
98+ sqlite3_stmt * stmt;
99+ rc = sqlite3_prepare_v2 (db, " select dimensions from lembed_models where name = ?" , -1 , &stmt, NULL );
100+ assert (rc == SQLITE_OK);
101+
102+ sqlite3_bind_text (stmt, 1 , " default" , -1 , SQLITE_STATIC);
103+
104+ rc = sqlite3_step (stmt);
105+ assert (rc == SQLITE_ROW);
106+ *dimensions = sqlite3_column_int64 (stmt, 0 );
107+ sqlite3_finalize (stmt);
108+
109+ return SQLITE_OK;
110+ }
111+
112+ int cmd_index (char * filename, char * target_column) {
113+ int rc;
114+ sqlite3* db = NULL ;
115+ sqlite3_stmt* stmt = NULL ;
116+ char * zDbPath = sqlite3_mprintf (" %s.db" , filename);
117+ assert (zDbPath);
118+
119+ rc = sqlite3_open (zDbPath, &db);
120+ assert (rc == SQLITE_OK);
121+
122+ rc = sqlite3_exec (db, " PRAGMA page_size=16384;" , NULL , NULL , NULL );
123+ assert (rc == SQLITE_OK);
124+
125+ rc = embedr_sqlite3_init (db);
126+ assert (rc == SQLITE_OK);
127+
128+ if (sqlite3_strlike (" %.csv" , filename, 0 ) == 0 ) {
129+ const char * zSql;
130+
131+ rc = sqlite3_exec (db, " BEGIN;" , NULL , NULL , NULL );
132+ assert (rc == SQLITE_OK);
133+
134+ zSql = sqlite3_mprintf (
135+ " CREATE VIRTUAL TABLE temp.source USING csv(filename=\" %w\" , header=yes)" ,
136+ filename
137+ );
138+ assert (zSql);
139+ rc = sqlite3_prepare_v2 (db, zSql, -1 , &stmt, NULL );
140+ assert (rc == SQLITE_OK);
141+ rc = sqlite3_step (stmt);
142+ assert (rc == SQLITE_DONE);
143+ sqlite3_finalize (stmt);
144+
145+ int64_t dimensions;
146+ rc = default_model_dimensions (db, &dimensions);
147+
148+ rc = sqlite3_exec (db, " CREATE TABLE source AS SELECT * FROM temp.source;" , NULL , NULL , NULL );
149+ assert (rc == SQLITE_OK);
150+
151+ zSql = sqlite3_mprintf (
152+ " CREATE VIRTUAL TABLE vec_source USING vec0(embedding float[%lld])" ,
153+ dimensions
154+ );
155+ assert (zSql);
156+ rc = sqlite3_prepare_v2 (db, zSql, -1 , &stmt, NULL );
157+ assert (rc == SQLITE_OK);
158+ rc = sqlite3_step (stmt);
159+ assert (rc == SQLITE_DONE);
160+ sqlite3_finalize (stmt);
161+
162+ int64_t nTotal;
163+ {
164+ sqlite3_stmt * stmt;
165+ rc = sqlite3_prepare_v2 (db, " SELECT count(*) FROM source" , -1 , &stmt, NULL );
166+ assert (rc == SQLITE_OK);
167+ rc = sqlite3_step (stmt);
168+ assert (rc == SQLITE_ROW);
169+ nTotal = sqlite3_column_int64 (stmt, 0 );
170+ sqlite3_finalize (stmt);
171+ }
172+
173+ int64_t nRemaining = nTotal;
174+
175+
176+ zSql = sqlite3_mprintf (
177+ " \
178+ WITH chunk AS ( \
179+ SELECT \
180+ source.rowid, \
181+ lembed(source.\" %w\" ) AS embedding \
182+ FROM source \
183+ WHERE source.rowid NOT IN (select rowid from vec_source) \
184+ LIMIT 256 \
185+ ) \
186+ INSERT INTO vec_source(rowid, embedding) \
187+ SELECT rowid, embedding FROM chunk \
188+ RETURNING rowid; \
189+ " ,
190+ target_column
191+ );
192+ assert (zSql);
193+
194+ rc = sqlite3_prepare_v2 (db, zSql, -1 , &stmt, NULL );
195+ assert (rc == SQLITE_OK);
196+
197+ int64_t nEmbed = 0 ;
198+ int64_t t0 = time_ms ();
199+
200+ while (1 ){
201+ sqlite3_reset (stmt);
202+
203+ int nChunkEmbed = 0 ;
204+ while (1 ) {
205+ rc = sqlite3_step (stmt);
206+ if (rc == SQLITE_DONE) {
207+ break ;
208+ }
209+ assert (rc == SQLITE_ROW);
210+ nChunkEmbed++;
211+ }
212+ if (nChunkEmbed == 0 ) {
213+ break ;
214+ }
215+ nEmbed += nChunkEmbed;
216+ nRemaining -= nChunkEmbed;
217+ print_progress_bar (nEmbed, nTotal, time_ms () - t0);
218+ }
219+ }
220+ else {
221+ printf (" Unknown filetype\n " );
222+ }
223+
224+ rc = sqlite3_exec (db, " COMMIT;" , NULL , NULL , NULL );
225+ assert (rc == SQLITE_OK);
226+
227+ sqlite3_free (zDbPath);
228+ sqlite3_close (db);
229+ return SQLITE_OK;
230+ }
231+
90232int cmd_backfill (char * dbPath, char * table, char * column) {
91233 int rc;
92234 sqlite3* db;
@@ -252,7 +394,8 @@ int main(int argc, char ** argv) {
252394 }
253395 else if (sqlite3_stricmp (arg, " --version" ) == 0 || sqlite3_stricmp (arg, " -v" ) == 0 ) {
254396 fprintf (stderr,
255- " llamafile-embed %s, SQLite %s, sqlite-vec=%s, sqlite-lembed=%s\n " ,
397+ " embedr %s, llamafile %s, SQLite %s, sqlite-vec=%s, sqlite-lembed=%s\n " ,
398+ EMBEDR_VERSION,
256399 LLAMAFILE_VERSION_STRING,
257400 sqlite3_version,
258401 SQLITE_VEC_VERSION,
@@ -268,12 +411,18 @@ int main(int argc, char ** argv) {
268411 return cmd_embed (argv[i+1 ]);
269412 }
270413 else if (sqlite3_stricmp (arg, " backfill" ) == 0 ) {
271- assert (i + 5 == argc);
414+ assert (i + 4 == argc);
272415 char * dbpath = argv[i+1 ];
273416 char * table = argv[i+2 ];
274417 char * column = argv[i+3 ];
275418 return cmd_backfill (dbpath, table, column);
276419 }
420+ else if (sqlite3_stricmp (arg, " index" ) == 0 ) {
421+ assert (i + 3 == argc);
422+ char * path = argv[i+1 ];
423+ char * column = argv[i+2 ];
424+ return cmd_index (path, column);
425+ }
277426 else {
278427 printf (" Unknown arg %s\n " , arg);
279428 return 1 ;
0 commit comments