
Commit 80a8c80

Merge pull request #3 from atasoglu/feat/performance-dev
Feat/performance dev
2 parents 2363209 + 46b299b commit 80a8c80

File tree

2 files changed (+187 / -71 lines)

benchmarks/runner.py

Lines changed: 46 additions & 45 deletions
@@ -29,50 +29,51 @@ def run_benchmark_suite(
     )
     client = SQLiteVecClient(table="benchmark", db_path=db_path)

-    # Create table
-    dim = config["dimension"]
-    distance = config["distance"]
-    client.create_table(dim=dim, distance=distance)
-
-    # Generate data
-    texts = generate_texts(dataset_size)
-    embeddings = generate_embeddings(dataset_size, dim)
-    metadata = generate_metadata(dataset_size)
-
-    # Benchmark: Add
-    print(f" Benchmarking add ({dataset_size} records)...")
-    results.append(benchmark_add(client, texts, embeddings, metadata))
-
-    # Get rowids for subsequent operations
-    rowids = list(range(1, dataset_size + 1))
-
-    # Benchmark: Get Many
-    print(f" Benchmarking get_many ({dataset_size} records)...")
-    results.append(benchmark_get_many(client, rowids))
-
-    # Benchmark: Similarity Search
-    print(" Benchmarking similarity_search...")
-    query_emb = [0.5] * dim
-    iterations = config["similarity_search"]["iterations"]
-    for top_k in config["similarity_search"]["top_k_values"]:
-        results.append(
-            benchmark_similarity_search(client, query_emb, top_k, iterations)
-        )
-
-    # Benchmark: Update Many
-    print(f" Benchmarking update_many ({dataset_size} records)...")
-    new_texts = [f"updated_{i}" for i in range(dataset_size)]
-    results.append(benchmark_update_many(client, rowids, new_texts))
-
-    # Benchmark: Get All
-    print(f" Benchmarking get_all ({dataset_size} records)...")
-    batch_size = config["batch_size"]
-    results.append(benchmark_get_all(client, dataset_size, batch_size))
-
-    # Benchmark: Delete Many
-    print(f" Benchmarking delete_many ({dataset_size} records)...")
-    results.append(benchmark_delete_many(client, rowids))
-
-    client.close()
+    try:
+        # Create table
+        dim = config["dimension"]
+        distance = config["distance"]
+        client.create_table(dim=dim, distance=distance)
+
+        # Generate data
+        texts = generate_texts(dataset_size)
+        embeddings = generate_embeddings(dataset_size, dim)
+        metadata = generate_metadata(dataset_size)
+
+        # Benchmark: Add
+        print(f" Benchmarking add ({dataset_size} records)...")
+        results.append(benchmark_add(client, texts, embeddings, metadata))
+
+        # Get rowids for subsequent operations
+        rowids = list(range(1, dataset_size + 1))
+
+        # Benchmark: Get Many
+        print(f" Benchmarking get_many ({dataset_size} records)...")
+        results.append(benchmark_get_many(client, rowids))
+
+        # Benchmark: Similarity Search
+        print(" Benchmarking similarity_search...")
+        query_emb = [0.5] * dim
+        iterations = config["similarity_search"]["iterations"]
+        for top_k in config["similarity_search"]["top_k_values"]:
+            results.append(
+                benchmark_similarity_search(client, query_emb, top_k, iterations)
+            )
+
+        # Benchmark: Update Many
+        print(f" Benchmarking update_many ({dataset_size} records)...")
+        new_texts = [f"updated_{i}" for i in range(dataset_size)]
+        results.append(benchmark_update_many(client, rowids, new_texts))
+
+        # Benchmark: Get All
+        print(f" Benchmarking get_all ({dataset_size} records)...")
+        batch_size = config["batch_size"]
+        results.append(benchmark_get_all(client, dataset_size, batch_size))
+
+        # Benchmark: Delete Many
+        print(f" Benchmarking delete_many ({dataset_size} records)...")
+        results.append(benchmark_delete_many(client, rowids))
+    finally:
+        client.close()

     return results
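The runner now releases the client in a finally block, so a failing benchmark can no longer leak the SQLite handle. Below is a minimal standalone sketch of the same pattern, using plain sqlite3 instead of SQLiteVecClient so it runs on its own:

```python
# A minimal sketch of the try/finally cleanup pattern introduced above,
# using plain sqlite3 so it runs standalone; SQLiteVecClient follows the
# same shape, with client.close() in the finally block.
import sqlite3

def run_suite(db_path: str) -> list[tuple]:
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("CREATE TABLE IF NOT EXISTS benchmark(text TEXT)")
        conn.execute("INSERT INTO benchmark(text) VALUES ('hello')")
        # Any step here may raise; cleanup still runs.
        return conn.execute("SELECT rowid, text FROM benchmark").fetchall()
    finally:
        # Runs on success and on error alike, so the SQLite handle
        # is always released even if a benchmark step fails.
        conn.close()

print(run_suite(":memory:"))
```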

sqlite_vec_client/base.py

Lines changed: 141 additions & 26 deletions
@@ -63,6 +63,13 @@ def create_connection(db_path: str) -> sqlite3.Connection:
         connection.enable_load_extension(True)
         sqlite_vec.load(connection)
         connection.enable_load_extension(False)
+
+        # Performance optimizations
+        connection.execute("PRAGMA journal_mode=WAL")
+        connection.execute("PRAGMA synchronous=NORMAL")
+        connection.execute("PRAGMA cache_size=-64000")  # 64MB cache
+        connection.execute("PRAGMA temp_store=MEMORY")
+
         logger.info(f"Successfully connected to database: {db_path}")
         return connection
     except sqlite3.Error as e:
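All four PRAGMAs are per-connection settings: WAL journaling lets readers proceed while a writer commits, synchronous=NORMAL skips the fsync on every WAL commit (durable against application crashes, though the last commits can be lost on power failure), a negative cache_size is interpreted in KiB (so -64000 is roughly 64 MB of page cache), and temp_store=MEMORY keeps temporary tables and indices in RAM. A standalone sketch that applies the same tuning and reads the effective values back:

```python
# Standalone sketch: apply the same PRAGMA tuning and verify it took effect.
# Uses a file path because journal_mode=WAL does not apply to :memory: databases.
import sqlite3

conn = sqlite3.connect("tuned.db")
conn.execute("PRAGMA journal_mode=WAL")     # readers don't block on writers
conn.execute("PRAGMA synchronous=NORMAL")   # fewer fsyncs; safe with WAL for app crashes
conn.execute("PRAGMA cache_size=-64000")    # negative value = size in KiB (~64 MB)
conn.execute("PRAGMA temp_store=MEMORY")    # temp tables/indices live in RAM

for pragma in ("journal_mode", "synchronous", "cache_size", "temp_store"):
    print(pragma, "=", conn.execute(f"PRAGMA {pragma}").fetchone()[0])
conn.close()
```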
@@ -262,31 +269,33 @@ def add(
         validate_embeddings_match(texts, embeddings, metadata)
         logger.debug(f"Adding {len(texts)} records to table '{self.table}'")
         try:
-            max_id = self.connection.execute(
-                f"SELECT max(rowid) as rowid FROM {self.table}"
-            ).fetchone()["rowid"]
-
-            if max_id is None:
-                max_id = 0
-
             if metadata is None:
                 metadata = [dict() for _ in texts]

             data_input = [
                 (text, json.dumps(md), serialize_f32(embedding))
                 for text, md, embedding in zip(texts, metadata, embeddings)
             ]
-            self.connection.executemany(
+
+            cur = self.connection.cursor()
+
+            # Get max rowid before insert
+            max_before = cur.execute(
+                f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}"
+            ).fetchone()[0]
+
+            cur.executemany(
                 f"""INSERT INTO {self.table}(text, metadata, text_embedding)
                 VALUES (?,?,?)""",
                 data_input,
             )
+
+            # Calculate rowids from max_before
+            rowids = list(range(max_before + 1, max_before + len(texts) + 1))
+
             if not self._in_transaction:
                 self.connection.commit()
-            results = self.connection.execute(
-                f"SELECT rowid FROM {self.table} WHERE rowid > {max_id}"
-            )
-            rowids = [row["rowid"] for row in results]
+
             logger.info(f"Added {len(rowids)} records to table '{self.table}'")
             return rowids
         except sqlite3.OperationalError as e:
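The rewritten add() derives the new rowids arithmetically from MAX(rowid) taken before the batch insert, replacing the post-insert SELECT ... WHERE rowid > ? query. This assumes SQLite's default rowid allocation (max + 1) and no other writer interleaving inserts into the same table mid-batch. A compact sketch of the technique against a bare table:

```python
# Sketch: derive inserted rowids from MAX(rowid) taken before executemany,
# assuming no concurrent writer interleaves inserts into the same table.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs(text TEXT)")

texts = ["a", "b", "c"]
cur = conn.cursor()
max_before = cur.execute("SELECT COALESCE(MAX(rowid), 0) FROM docs").fetchone()[0]
cur.executemany("INSERT INTO docs(text) VALUES (?)", [(t,) for t in texts])
conn.commit()

# New rowids are max_before+1 .. max_before+len(texts); no second SELECT needed.
rowids = list(range(max_before + 1, max_before + len(texts) + 1))
print(rowids)  # [1, 2, 3]
```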
@@ -447,15 +456,25 @@ def delete_many(self, rowids: list[int]) -> int:
         if not rowids:
             return 0
         logger.debug(f"Deleting {len(rowids)} records")
-        placeholders = ",".join(["?"] * len(rowids))
+
+        # SQLite has a limit on SQL variables (typically 999 or 32766)
+        # Split into chunks to avoid "too many SQL variables" error
+        chunk_size = 500
         cur = self.connection.cursor()
-        cur.execute(
-            f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
-            rowids,
-        )
+        deleted_count = 0
+
+        for i in range(0, len(rowids), chunk_size):
+            chunk = rowids[i : i + chunk_size]
+            placeholders = ",".join(["?"] * len(chunk))
+            cur.execute(
+                f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
+                chunk,
+            )
+            deleted_count += cur.rowcount
+
         if not self._in_transaction:
             self.connection.commit()
-        deleted_count = cur.rowcount
+
         logger.info(f"Deleted {deleted_count} records from table '{self.table}'")
         return deleted_count
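Chunking the IN (...) list keeps each DELETE under SQLite's bound-parameter ceiling (SQLITE_MAX_VARIABLE_NUMBER, 999 in older builds and 32766 since SQLite 3.32), and accumulating cur.rowcount per statement preserves an accurate delete count. The same pattern in isolation, with a hypothetical docs table:

```python
# Sketch: delete an arbitrarily long rowid list in fixed-size chunks so each
# statement stays below SQLite's bound-parameter limit.
import sqlite3

def delete_rowids(conn: sqlite3.Connection, table: str, rowids: list[int],
                  chunk_size: int = 500) -> int:
    cur = conn.cursor()
    deleted = 0
    for i in range(0, len(rowids), chunk_size):
        chunk = rowids[i : i + chunk_size]
        placeholders = ",".join("?" * len(chunk))
        cur.execute(f"DELETE FROM {table} WHERE rowid IN ({placeholders})", chunk)
        deleted += cur.rowcount  # rowcount is per-statement; accumulate it
    conn.commit()
    return deleted

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs(text TEXT)")
conn.executemany("INSERT INTO docs(text) VALUES (?)", [(str(i),) for i in range(2000)])
print(delete_rowids(conn, "docs", list(range(1, 2001))))  # 2000
```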

@@ -475,10 +494,93 @@ def update_many(
         if not updates:
             return 0
         logger.debug(f"Updating {len(updates)} records")
-        updated_count = 0
+
+        # Group updates by which fields are being updated
+        text_updates = []
+        metadata_updates = []
+        embedding_updates = []
+        full_updates = []
+
+        mixed_updates = []
+
         for rowid, text, metadata, embedding in updates:
-            if self.update(rowid, text=text, metadata=metadata, embedding=embedding):
-                updated_count += 1
+            has_text = text is not None
+            has_metadata = metadata is not None
+            has_embedding = embedding is not None
+
+            if has_text and has_metadata and has_embedding:
+                if text is not None and metadata is not None and embedding is not None:
+                    full_updates.append(
+                        (text, json.dumps(metadata), serialize_f32(embedding), rowid)
+                    )
+            elif has_text and not has_metadata and not has_embedding:
+                text_updates.append((text, rowid))
+            elif has_metadata and not has_text and not has_embedding:
+                metadata_updates.append((json.dumps(metadata), rowid))
+            elif has_embedding and not has_text and not has_metadata:
+                if embedding is not None:
+                    embedding_updates.append((serialize_f32(embedding), rowid))
+            else:
+                # Mixed updates - store for individual execution
+                mixed_updates.append((rowid, text, metadata, embedding))
+
+        cur = self.connection.cursor()
+        updated_count = 0
+
+        # Batch execute grouped updates
+        if full_updates:
+            cur.executemany(
+                f"""
+                UPDATE {self.table}
+                SET text = ?, metadata = ?, text_embedding = ? WHERE rowid = ?
+                """,
+                full_updates,
+            )
+            updated_count += cur.rowcount
+
+        if text_updates:
+            cur.executemany(
+                f"UPDATE {self.table} SET text = ? WHERE rowid = ?", text_updates
+            )
+            updated_count += cur.rowcount
+
+        if metadata_updates:
+            cur.executemany(
+                f"UPDATE {self.table} SET metadata = ? WHERE rowid = ?",
+                metadata_updates,
+            )
+            updated_count += cur.rowcount
+
+        if embedding_updates:
+            cur.executemany(
+                f"UPDATE {self.table} SET text_embedding = ? WHERE rowid = ?",
+                embedding_updates,
+            )
+            updated_count += cur.rowcount
+
+        # Handle mixed updates individually
+        for rowid, text, metadata, embedding in mixed_updates:
+            sets = []
+            params: list[Any] = []
+            if text is not None:
+                sets.append("text = ?")
+                params.append(text)
+            if metadata is not None:
+                sets.append("metadata = ?")
+                params.append(json.dumps(metadata))
+            if embedding is not None:
+                sets.append("text_embedding = ?")
+                params.append(serialize_f32(embedding))
+            params.append(rowid)
+
+            if sets:
+                sql = f"UPDATE {self.table} SET " + ", ".join(sets) + " WHERE rowid = ?"
+                cur.execute(sql, params)
+                updated_count += cur.rowcount
+
+        if not self._in_transaction:
+            self.connection.commit()
+
         logger.info(f"Updated {updated_count} records in table '{self.table}'")
         return updated_count
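update_many now buckets updates into homogeneous groups so each group runs as a single executemany round-trip, and only rows that touch an unusual mix of fields fall back to per-row, dynamically built UPDATEs. A reduced two-field sketch of the grouping idea (the docs table and its columns are illustrative, not the library's schema):

```python
# Reduced sketch of the grouping strategy: batch homogeneous updates with
# executemany, fall back to per-row SQL only for mixed field combinations.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs(text TEXT, metadata TEXT)")
conn.executemany("INSERT INTO docs(text, metadata) VALUES (?, ?)",
                 [(f"t{i}", "{}") for i in range(4)])

# (rowid, text, metadata); None means "leave this field alone"
updates = [(1, "new1", None), (2, "new2", None),
           (3, None, '{"k": 1}'), (4, "new4", '{"k": 2}')]

text_only, meta_only, mixed = [], [], []
for rowid, text, meta in updates:
    if text is not None and meta is None:
        text_only.append((text, rowid))
    elif meta is not None and text is None:
        meta_only.append((meta, rowid))
    else:
        mixed.append((rowid, text, meta))

cur = conn.cursor()
cur.executemany("UPDATE docs SET text = ? WHERE rowid = ?", text_only)
cur.executemany("UPDATE docs SET metadata = ? WHERE rowid = ?", meta_only)
for rowid, text, meta in mixed:  # rare path: one statement per row
    cur.execute("UPDATE docs SET text = ?, metadata = ? WHERE rowid = ?",
                (text, meta, rowid))
conn.commit()
print(conn.execute("SELECT rowid, text, metadata FROM docs").fetchall())
```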

@@ -493,13 +595,26 @@ def get_all(self, batch_size: int = 100) -> Generator[Result, None, None]:
         """
         validate_limit(batch_size)
         logger.debug(f"Fetching all records with batch_size={batch_size}")
-        offset = 0
+        last_rowid = 0
+        cursor = self.connection.cursor()
+
         while True:
-            batch = self.list_results(limit=batch_size, offset=offset)
-            if not batch:
+            cursor.execute(
+                f"""
+                SELECT rowid, text, metadata, text_embedding FROM {self.table}
+                WHERE rowid > ?
+                ORDER BY rowid ASC
+                LIMIT ?
+                """,
+                [last_rowid, batch_size],
+            )
+            rows = cursor.fetchall()
+            if not rows:
                 break
-            yield from batch
-            offset += batch_size
+
+            results = self.rows_to_results(rows)
+            yield from results
+            last_rowid = results[-1][0]  # Get last rowid from batch

     @contextmanager
     def transaction(self) -> Generator[None, None, None]:
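get_all switches from LIMIT/OFFSET paging to keyset pagination: each batch resumes at the last rowid seen, so SQLite seeks straight to the next row instead of re-scanning and discarding offset rows, turning the total scan cost from quadratic to linear in table size. The pattern in isolation, against a minimal stand-in table:

```python
# Sketch of keyset pagination: resume each batch from the last rowid seen
# rather than skipping `offset` rows on every query.
import sqlite3
from collections.abc import Generator

def iter_rows(conn: sqlite3.Connection, batch_size: int = 100) -> Generator[tuple, None, None]:
    last_rowid = 0
    cur = conn.cursor()
    while True:
        rows = cur.execute(
            "SELECT rowid, text FROM docs WHERE rowid > ? ORDER BY rowid LIMIT ?",
            (last_rowid, batch_size),
        ).fetchall()
        if not rows:
            break
        yield from rows
        last_rowid = rows[-1][0]  # rowid of the last row in this batch

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs(text TEXT)")
conn.executemany("INSERT INTO docs(text) VALUES (?)", [(str(i),) for i in range(250)])
print(sum(1 for _ in iter_rows(conn)))  # 250
```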
