1+ """High-level client for vector search on SQLite using the sqlite-vec extension.
2+
3+ This module provides `SQLiteVecClient`, a thin wrapper around `sqlite3` and
4+ `sqlite-vec` to store texts, JSON metadata, and float32 embeddings, and to run
5+ similarity search through a virtual vector table.
6+ """
7+
18from __future__ import annotations
29import json
310import sqlite3
916
1017
1118class SQLiteVecClient :
19+ """Manage a text+embedding table and its sqlite-vec index.
20+
21+ The client maintains two tables:
22+ - `{table}`: base table with columns `text`, `metadata`, `text_embedding`.
23+ - `{table}_vec`: `vec0` virtual table mirroring embeddings for ANN search.
24+
25+ It exposes CRUD helpers and `similarity_search` over embeddings.
26+ """
27+
1228 @staticmethod
13- def create_connection (db_path : str ):
29+ def create_connection (db_path : str ) -> sqlite3 .Connection :
30+ """Create a SQLite connection with sqlite-vec extension loaded."""
1431 connection = sqlite3 .connect (db_path )
1532 connection .row_factory = sqlite3 .Row
1633 connection .enable_load_extension (True )
@@ -20,6 +37,7 @@ def create_connection(db_path: str):
2037
2138 @staticmethod
2239 def rows_to_results (rows : List [sqlite3 .Row ]) -> List [Result ]:
40+ """Convert `sqlite3.Row` items into `(rowid, text, metadata, embedding)`."""
2341 return [
2442 (
2543 row ["rowid" ],
@@ -30,11 +48,13 @@ def rows_to_results(rows: List[sqlite3.Row]) -> List[Result]:
3048 for row in rows
3149 ]
3250
def __init__(self, table: str, db_path: str) -> None:
    """Bind the client to one base table and open its database connection.

    Args:
        table: Name of the base table; kept on the instance as ``table``.
        db_path: Database path passed on to ``create_connection``.
    """
    self.table = table
    opened = self.create_connection(db_path)
    self.connection = opened
3655
def __enter__(self) -> SQLiteVecClient:
    """Enter a ``with`` block and hand back this very client.

    Returns:
        The same instance the ``with`` statement was invoked on.
    """
    client = self
    return client
3959
4060 def __exit__ (
@@ -43,6 +63,7 @@ def __exit__(
4363 exc : Optional [BaseException ],
4464 tb : Optional [TracebackType ],
4565 ) -> Optional [bool ]:
66+ """Close the connection on exit; do not suppress exceptions."""
4667 self .close ()
4768 return False
4869
@@ -51,6 +72,7 @@ def create_table(
5172 dim : int ,
5273 distance : Literal ["L1" , "L2" , "cosine" ] = "cosine" ,
5374 ) -> None :
75+ """Create base table, vector table, and triggers to keep them in sync."""
5476 self .connection .execute (
5577 f"""
5678 CREATE TABLE IF NOT EXISTS { self .table }
@@ -112,6 +134,7 @@ def similarity_search(
112134 embedding : Embeddings ,
113135 top_k : int = 5 ,
114136 ) -> List [SimilaritySearchResult ]:
137+ """Return top-k nearest neighbors for the given embedding."""
115138 cursor = self .connection .cursor ()
116139 cursor .execute (
117140 f"""
@@ -137,6 +160,7 @@ def add(
137160 embeddings : List [Embeddings ],
138161 metadata : List [Metadata ] = None ,
139162 ) -> Rowids :
163+ """Insert texts with embeddings (and optional metadata) and return rowids."""
140164 max_id = self .connection .execute (
141165 f"SELECT max(rowid) as rowid FROM { self .table } "
142166 ).fetchone ()["rowid" ]
@@ -162,6 +186,7 @@ def add(
162186 return [row ["rowid" ] for row in results ]
163187
164188 def get_by_id (self , rowid : int ) -> Optional [Result ]:
189+ """Get a single record by rowid; return `None` if not found."""
165190 cursor = self .connection .cursor ()
166191 cursor .execute (
167192 f"SELECT rowid, text, metadata, text_embedding FROM { self .table } WHERE rowid = ?" ,
@@ -173,6 +198,7 @@ def get_by_id(self, rowid: int) -> Optional[Result]:
173198 return self .rows_to_results ([row ])[0 ]
174199
175200 def get_many (self , rowids : List [int ]) -> List [Result ]:
201+ """Get multiple records by rowids; returns empty list if input is empty."""
176202 if not rowids :
177203 return []
178204 placeholders = "," .join (["?" ] * len (rowids ))
@@ -185,6 +211,7 @@ def get_many(self, rowids: List[int]) -> List[Result]:
185211 return self .rows_to_results (rows )
186212
187213 def get_by_text (self , text : str ) -> List [Result ]:
214+ """Get all records with exact `text`, ordered by rowid ascending."""
188215 cursor = self .connection .cursor ()
189216 cursor .execute (
190217 f"""
@@ -198,6 +225,7 @@ def get_by_text(self, text: str) -> List[Result]:
198225 return self .rows_to_results (rows )
199226
200227 def get_by_metadata (self , metadata : Dict [str , Any ]) -> List [Result ]:
228+ """Get all records whose metadata exactly equals the given dict."""
201229 cursor = self .connection .cursor ()
202230 cursor .execute (
203231 f"""
@@ -216,6 +244,7 @@ def list(
216244 offset : int = 0 ,
217245 order : Literal ["asc" , "desc" ] = "asc" ,
218246 ) -> List [Result ]:
247+ """List records with pagination and order by rowid."""
219248 cursor = self .connection .cursor ()
220249 cursor .execute (
221250 f"""
@@ -229,6 +258,7 @@ def list(
229258 return self .rows_to_results (rows )
230259
231260 def count (self ) -> int :
261+ """Return the total number of rows in the base table."""
232262 cursor = self .connection .cursor ()
233263 cursor .execute (f"SELECT COUNT(1) as c FROM { self .table } " )
234264 row = cursor .fetchone ()
@@ -242,6 +272,7 @@ def update(
242272 metadata : Optional [Metadata ] = None ,
243273 embedding : Optional [Embeddings ] = None ,
244274 ) -> bool :
275+ """Update fields of a record by rowid; return True if a row changed."""
245276 sets = []
246277 params : List [Any ] = []
247278 if text is not None :
@@ -265,12 +296,14 @@ def update(
265296 return cur .rowcount > 0
266297
def delete_by_id(self, rowid: int) -> bool:
    """Remove the record with the given rowid from the base table.

    Args:
        rowid: Rowid of the record to delete.

    Returns:
        True when the DELETE affected at least one row, otherwise False.
    """
    statement = f"DELETE FROM {self.table} WHERE rowid = ?"
    # Connection.execute opens a fresh cursor, runs the statement and returns it.
    cursor = self.connection.execute(statement, [rowid])
    # Commit before reporting so the deletion is persisted.
    self.connection.commit()
    removed = cursor.rowcount
    return removed > 0
272304
273305 def delete_many (self , rowids : List [int ]) -> int :
306+ """Delete multiple records by rowids; return number of rows removed."""
274307 if not rowids :
275308 return 0
276309 placeholders = "," .join (["?" ] * len (rowids ))
@@ -283,6 +316,7 @@ def delete_many(self, rowids: List[int]) -> int:
283316 return cur .rowcount
284317
285318 def close (self ) -> None :
319+ """Close the underlying SQLite connection, suppressing close errors."""
286320 try :
287321 self .connection .close ()
288322 except Exception :
0 commit comments