66"""
77
88from __future__ import annotations
9+
910import json
1011import sqlite3
11- import sqlite_vec
1212from types import TracebackType
13- from typing import Optional , List , Dict , Any , Literal , Type
14- from .types import Embeddings , Result , SimilaritySearchResult , Rowids , Metadata , Text
15- from .utils import serialize_f32 , deserialize_f32
13+ from typing import Any , Literal
14+
15+ import sqlite_vec
16+
17+ from .types import Embeddings , Metadata , Result , Rowids , SimilaritySearchResult , Text
18+ from .utils import deserialize_f32 , serialize_f32
1619
1720
1821class SQLiteVecClient :
@@ -36,7 +39,7 @@ def create_connection(db_path: str) -> sqlite3.Connection:
3639 return connection
3740
3841 @staticmethod
39- def rows_to_results (rows : List [sqlite3 .Row ]) -> List [Result ]:
42+ def rows_to_results (rows : list [sqlite3 .Row ]) -> list [Result ]:
4043 """Convert `sqlite3.Row` items into `(rowid, text, metadata, embedding)`."""
4144 return [
4245 (
@@ -59,13 +62,12 @@ def __enter__(self) -> SQLiteVecClient:
5962
6063 def __exit__ (
6164 self ,
62- exc_type : Optional [ Type [ BaseException ]] ,
63- exc : Optional [ BaseException ] ,
64- tb : Optional [ TracebackType ] ,
65- ) -> Optional [ bool ] :
65+ exc_type : type [ BaseException ] | None ,
66+ exc : BaseException | None ,
67+ tb : TracebackType | None ,
68+ ) -> None :
6669 """Close the connection on exit; do not suppress exceptions."""
6770 self .close ()
68- return False
6971
7072 def create_table (
7173 self ,
@@ -96,11 +98,11 @@ def create_table(
9698 )
9799 self .connection .execute (
98100 f"""
99- CREATE TRIGGER IF NOT EXISTS { self .table } _embed_text
101+ CREATE TRIGGER IF NOT EXISTS { self .table } _embed_text
100102 AFTER INSERT ON { self .table }
101103 BEGIN
102104 INSERT INTO { self .table } _vec(rowid, text_embedding)
103- VALUES (new.rowid, new.text_embedding)
105+ VALUES (new.rowid, new.text_embedding)
104106 ;
105107 END;
106108 """
@@ -133,7 +135,7 @@ def similarity_search(
133135 self ,
134136 embedding : Embeddings ,
135137 top_k : int = 5 ,
136- ) -> List [SimilaritySearchResult ]:
138+ ) -> list [SimilaritySearchResult ]:
137139 """Return top-k nearest neighbors for the given embedding."""
138140 cursor = self .connection .cursor ()
139141 cursor .execute (
@@ -143,7 +145,7 @@ def similarity_search(
143145 text,
144146 distance
145147 FROM { self .table } AS e
146- INNER JOIN { self .table } _vec AS v on v.rowid = e.rowid
148+ INNER JOIN { self .table } _vec AS v on v.rowid = e.rowid
147149 WHERE
148150 v.text_embedding MATCH ?
149151 AND k = ?
@@ -156,9 +158,9 @@ def similarity_search(
156158
157159 def add (
158160 self ,
159- texts : List [Text ],
160- embeddings : List [Embeddings ],
161- metadata : List [Metadata ] = None ,
161+ texts : list [Text ],
162+ embeddings : list [Embeddings ],
163+ metadata : list [Metadata ] | None = None ,
162164 ) -> Rowids :
163165 """Insert texts with embeddings (and optional metadata) and return rowids."""
164166 max_id = self .connection .execute (
@@ -176,7 +178,8 @@ def add(
176178 for text , md , embedding in zip (texts , metadata , embeddings )
177179 ]
178180 self .connection .executemany (
179- f"INSERT INTO { self .table } (text, metadata, text_embedding) VALUES (?,?,?)" ,
181+ f"""INSERT INTO { self .table } (text, metadata, text_embedding)
182+ VALUES (?,?,?)""" ,
180183 data_input ,
181184 )
182185 self .connection .commit ()
@@ -185,32 +188,36 @@ def add(
185188 )
186189 return [row ["rowid" ] for row in results ]
187190
188- def get_by_id (self , rowid : int ) -> Optional [ Result ] :
191+ def get_by_id (self , rowid : int ) -> Result | None :
189192 """Get a single record by rowid; return `None` if not found."""
190193 cursor = self .connection .cursor ()
191194 cursor .execute (
192- f"SELECT rowid, text, metadata, text_embedding FROM { self .table } WHERE rowid = ?" ,
195+ f"""
196+ SELECT rowid, text, metadata, text_embedding
197+ FROM { self .table } WHERE rowid = ?
198+ """ ,
193199 [rowid ],
194200 )
195201 row = cursor .fetchone ()
196202 if row is None :
197203 return None
198204 return self .rows_to_results ([row ])[0 ]
199205
200- def get_many (self , rowids : List [int ]) -> List [Result ]:
206+ def get_many (self , rowids : list [int ]) -> list [Result ]:
201207 """Get multiple records by rowids; returns empty list if input is empty."""
202208 if not rowids :
203209 return []
204210 placeholders = "," .join (["?" ] * len (rowids ))
205211 cursor = self .connection .cursor ()
206212 cursor .execute (
207- f"SELECT rowid, text, metadata, text_embedding FROM { self .table } WHERE rowid IN ({ placeholders } )" ,
213+ f"""SELECT rowid, text, metadata, text_embedding FROM { self .table }
214+ WHERE rowid IN ({ placeholders } )""" ,
208215 rowids ,
209216 )
210217 rows = cursor .fetchall ()
211218 return self .rows_to_results (rows )
212219
213- def get_by_text (self , text : str ) -> List [Result ]:
220+ def get_by_text (self , text : str ) -> list [Result ]:
214221 """Get all records with exact `text`, ordered by rowid ascending."""
215222 cursor = self .connection .cursor ()
216223 cursor .execute (
@@ -224,7 +231,7 @@ def get_by_text(self, text: str) -> List[Result]:
224231 rows = cursor .fetchall ()
225232 return self .rows_to_results (rows )
226233
227- def get_by_metadata (self , metadata : Dict [str , Any ]) -> List [Result ]:
234+ def get_by_metadata (self , metadata : dict [str , Any ]) -> list [Result ]:
228235 """Get all records whose metadata exactly equals the given dict."""
229236 cursor = self .connection .cursor ()
230237 cursor .execute (
@@ -238,12 +245,12 @@ def get_by_metadata(self, metadata: Dict[str, Any]) -> List[Result]:
238245 rows = cursor .fetchall ()
239246 return self .rows_to_results (rows )
240247
241- def list (
248+ def list_results (
242249 self ,
243250 limit : int = 50 ,
244251 offset : int = 0 ,
245252 order : Literal ["asc" , "desc" ] = "asc" ,
246- ) -> List [Result ]:
253+ ) -> list [Result ]:
247254 """List records with pagination and order by rowid."""
248255 cursor = self .connection .cursor ()
249256 cursor .execute (
@@ -268,13 +275,13 @@ def update(
268275 self ,
269276 rowid : int ,
270277 * ,
271- text : Optional [ str ] = None ,
272- metadata : Optional [ Metadata ] = None ,
273- embedding : Optional [ Embeddings ] = None ,
278+ text : str | None = None ,
279+ metadata : Metadata | None = None ,
280+ embedding : Embeddings | None = None ,
274281 ) -> bool :
275282 """Update fields of a record by rowid; return True if a row changed."""
276283 sets = []
277- params : List [Any ] = []
284+ params : list [Any ] = []
278285 if text is not None :
279286 sets .append ("text = ?" )
280287 params .append (text )
@@ -302,7 +309,7 @@ def delete_by_id(self, rowid: int) -> bool:
302309 self .connection .commit ()
303310 return cur .rowcount > 0
304311
305- def delete_many (self , rowids : List [int ]) -> int :
312+ def delete_many (self , rowids : list [int ]) -> int :
306313 """Delete multiple records by rowids; return number of rows removed."""
307314 if not rowids :
308315 return 0
0 commit comments