Skip to content

Commit 605a352

Browse files
committed
Add docstrings to modules and methods in sqlite_vec_client for improved documentation and clarity
1 parent c40a77e commit 605a352

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

sqlite_vec_client/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""Public package API for the sqlite-vec client.
2+
3+
Exposes `SQLiteVecClient` as the primary entry point.
4+
"""
5+
16
from .base import SQLiteVecClient
27

38
__all__ = ["SQLiteVecClient"]

sqlite_vec_client/base.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""High-level client for vector search on SQLite using the sqlite-vec extension.
2+
3+
This module provides `SQLiteVecClient`, a thin wrapper around `sqlite3` and
4+
`sqlite-vec` to store texts, JSON metadata, and float32 embeddings, and to run
5+
similarity search through a virtual vector table.
6+
"""
7+
18
from __future__ import annotations
29
import json
310
import sqlite3
@@ -9,8 +16,18 @@
916

1017

1118
class SQLiteVecClient:
19+
"""Manage a text+embedding table and its sqlite-vec index.
20+
21+
The client maintains two tables:
22+
- `{table}`: base table with columns `text`, `metadata`, `text_embedding`.
23+
- `{table}_vec`: `vec0` virtual table mirroring embeddings for ANN search.
24+
25+
It exposes CRUD helpers and `similarity_search` over embeddings.
26+
"""
27+
1228
@staticmethod
13-
def create_connection(db_path: str):
29+
def create_connection(db_path: str) -> sqlite3.Connection:
30+
"""Create a SQLite connection with sqlite-vec extension loaded."""
1431
connection = sqlite3.connect(db_path)
1532
connection.row_factory = sqlite3.Row
1633
connection.enable_load_extension(True)
@@ -20,6 +37,7 @@ def create_connection(db_path: str):
2037

2138
@staticmethod
2239
def rows_to_results(rows: List[sqlite3.Row]) -> List[Result]:
40+
"""Convert `sqlite3.Row` items into `(rowid, text, metadata, embedding)`."""
2341
return [
2442
(
2543
row["rowid"],
@@ -30,11 +48,13 @@ def rows_to_results(rows: List[sqlite3.Row]) -> List[Result]:
3048
for row in rows
3149
]
3250

33-
def __init__(self, table: str, db_path: str):
51+
def __init__(self, table: str, db_path: str) -> None:
52+
"""Initialize the client for a given base table and database file."""
3453
self.table = table
3554
self.connection = self.create_connection(db_path)
3655

3756
def __enter__(self) -> SQLiteVecClient:
57+
"""Support context manager protocol and return `self`."""
3858
return self
3959

4060
def __exit__(
@@ -43,6 +63,7 @@ def __exit__(
4363
exc: Optional[BaseException],
4464
tb: Optional[TracebackType],
4565
) -> Optional[bool]:
66+
"""Close the connection on exit; do not suppress exceptions."""
4667
self.close()
4768
return False
4869

@@ -51,6 +72,7 @@ def create_table(
5172
dim: int,
5273
distance: Literal["L1", "L2", "cosine"] = "cosine",
5374
) -> None:
75+
"""Create base table, vector table, and triggers to keep them in sync."""
5476
self.connection.execute(
5577
f"""
5678
CREATE TABLE IF NOT EXISTS {self.table}
@@ -112,6 +134,7 @@ def similarity_search(
112134
embedding: Embeddings,
113135
top_k: int = 5,
114136
) -> List[SimilaritySearchResult]:
137+
"""Return top-k nearest neighbors for the given embedding."""
115138
cursor = self.connection.cursor()
116139
cursor.execute(
117140
f"""
@@ -137,6 +160,7 @@ def add(
137160
embeddings: List[Embeddings],
138161
metadata: List[Metadata] = None,
139162
) -> Rowids:
163+
"""Insert texts with embeddings (and optional metadata) and return rowids."""
140164
max_id = self.connection.execute(
141165
f"SELECT max(rowid) as rowid FROM {self.table}"
142166
).fetchone()["rowid"]
@@ -162,6 +186,7 @@ def add(
162186
return [row["rowid"] for row in results]
163187

164188
def get_by_id(self, rowid: int) -> Optional[Result]:
189+
"""Get a single record by rowid; return `None` if not found."""
165190
cursor = self.connection.cursor()
166191
cursor.execute(
167192
f"SELECT rowid, text, metadata, text_embedding FROM {self.table} WHERE rowid = ?",
@@ -173,6 +198,7 @@ def get_by_id(self, rowid: int) -> Optional[Result]:
173198
return self.rows_to_results([row])[0]
174199

175200
def get_many(self, rowids: List[int]) -> List[Result]:
201+
"""Get multiple records by rowids; returns empty list if input is empty."""
176202
if not rowids:
177203
return []
178204
placeholders = ",".join(["?"] * len(rowids))
@@ -185,6 +211,7 @@ def get_many(self, rowids: List[int]) -> List[Result]:
185211
return self.rows_to_results(rows)
186212

187213
def get_by_text(self, text: str) -> List[Result]:
214+
"""Get all records with exact `text`, ordered by rowid ascending."""
188215
cursor = self.connection.cursor()
189216
cursor.execute(
190217
f"""
@@ -198,6 +225,7 @@ def get_by_text(self, text: str) -> List[Result]:
198225
return self.rows_to_results(rows)
199226

200227
def get_by_metadata(self, metadata: Dict[str, Any]) -> List[Result]:
228+
"""Get all records whose metadata exactly equals the given dict."""
201229
cursor = self.connection.cursor()
202230
cursor.execute(
203231
f"""
@@ -216,6 +244,7 @@ def list(
216244
offset: int = 0,
217245
order: Literal["asc", "desc"] = "asc",
218246
) -> List[Result]:
247+
"""List records with pagination and order by rowid."""
219248
cursor = self.connection.cursor()
220249
cursor.execute(
221250
f"""
@@ -229,6 +258,7 @@ def list(
229258
return self.rows_to_results(rows)
230259

231260
def count(self) -> int:
261+
"""Return the total number of rows in the base table."""
232262
cursor = self.connection.cursor()
233263
cursor.execute(f"SELECT COUNT(1) as c FROM {self.table}")
234264
row = cursor.fetchone()
@@ -242,6 +272,7 @@ def update(
242272
metadata: Optional[Metadata] = None,
243273
embedding: Optional[Embeddings] = None,
244274
) -> bool:
275+
"""Update fields of a record by rowid; return True if a row changed."""
245276
sets = []
246277
params: List[Any] = []
247278
if text is not None:
@@ -265,12 +296,14 @@ def update(
265296
return cur.rowcount > 0
266297

267298
def delete_by_id(self, rowid: int) -> bool:
299+
"""Delete a single record by rowid; return True if a row was removed."""
268300
cur = self.connection.cursor()
269301
cur.execute(f"DELETE FROM {self.table} WHERE rowid = ?", [rowid])
270302
self.connection.commit()
271303
return cur.rowcount > 0
272304

273305
def delete_many(self, rowids: List[int]) -> int:
306+
"""Delete multiple records by rowids; return number of rows removed."""
274307
if not rowids:
275308
return 0
276309
placeholders = ",".join(["?"] * len(rowids))
@@ -283,6 +316,7 @@ def delete_many(self, rowids: List[int]) -> int:
283316
return cur.rowcount
284317

285318
def close(self) -> None:
319+
"""Close the underlying SQLite connection, suppressing close errors."""
286320
try:
287321
self.connection.close()
288322
except Exception:

sqlite_vec_client/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List, Dict, Any, Tuple, TypeAlias
1+
"""Type aliases used across the sqlite-vec client package."""
22

3+
from typing import List, Dict, Any, Tuple, TypeAlias
34

45
Text: TypeAlias = str
56
Rowid: TypeAlias = int

sqlite_vec_client/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
"""Utilities for serializing and deserializing float32 embedding arrays."""
2+
13
import struct
24
from .types import Embeddings
35

46

57
def serialize_f32(embeddings: Embeddings) -> bytes:
8+
"""Serialize a list of float32 values into a packed bytes blob."""
69
return struct.pack("%sf" % len(embeddings), *embeddings)
710

811

912
def deserialize_f32(blob: bytes) -> Embeddings:
13+
"""Deserialize a bytes blob of float32 values back into a list of floats."""
1014
return list(struct.unpack("%sf" % (len(blob) // 4), blob))

0 commit comments

Comments
 (0)