Skip to content

Commit 93c842f

Browse files
authored
feat: integration tests for pgvector (#495)
1 parent 27f553d commit 93c842f

File tree

13 files changed

+264
-166
lines changed

13 files changed

+264
-166
lines changed

.github/workflows/ci.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,21 @@ jobs:
155155
matrix:
156156
python-version: ["3.10", "3.11", "3.12"]
157157

158+
services:
159+
postgres:
160+
image: postgres:16
161+
env:
162+
POSTGRES_USER: postgres
163+
POSTGRES_PASSWORD: postgres
164+
POSTGRES_DB: test_db
165+
ports:
166+
- 5432:5432
167+
options: >-
168+
--health-cmd "pg_isready -U postgres"
169+
--health-interval 10s
170+
--health-timeout 5s
171+
--health-retries 5
172+
158173
steps:
159174
- uses: actions/checkout@v4
160175

@@ -182,7 +197,14 @@ jobs:
182197
path: ~/nltk_data
183198
key: nltk-${{ runner.os }}
184199

200+
- name: Install PostgreSQL + pgvector
201+
run: |
202+
sudo apt-get update
203+
sudo apt-get install -y postgresql-16-pgvector
204+
185205
- name: Run Tests With Coverage
206+
env:
207+
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/test_db
186208
run: |
187209
# run with coverage to not execute tests twice
188210
uv run coverage run -m pytest -v -p no:warnings --junitxml=report.xml

.libraries-whitelist.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@ chroma-hnswlib
55
rouge
66
distilabel
77
rerankers
8-
py_rust_stemmers
8+
py_rust_stemmers
9+
mirakuru
10+
psycopg
11+
pytest-postgresql

examples/document-search/chroma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def main() -> None:
7373
Run the example.
7474
"""
7575
embedder = LiteLLMEmbedder(
76-
model="text-embedding-3-small",
76+
model_name="text-embedding-3-small",
7777
default_options=LiteLLMEmbedderOptions(
7878
dimensions=1024,
7979
timeout=1000,

examples/document-search/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def main() -> None:
8484
Run the example.
8585
"""
8686
embedder = LiteLLMEmbedder(
87-
model="text-embedding-3-small",
87+
model_name="text-embedding-3-small",
8888
)
8989
vector_store = QdrantVectorStore(
9090
client=AsyncQdrantClient(

examples/document-search/otel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def main() -> None:
9999
Run the example.
100100
"""
101101
embedder = LiteLLMEmbedder(
102-
model="text-embedding-3-small",
102+
model_name="text-embedding-3-small",
103103
)
104104
vector_store = ChromaVectorStore(
105105
client=EphemeralClient(),

examples/document-search/pgvector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def main() -> None:
8282
database_url = "postgresql://ragbits_example:ragbits_example@localhost/ragbits_example"
8383
async with asyncpg.create_pool(dsn=database_url) as pool:
8484
embedder = LiteLLMEmbedder(
85-
model="text-embedding-3-small",
85+
model_name="text-embedding-3-small",
8686
)
8787
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="example", vector_size=1536)
8888
document_search = DocumentSearch(

examples/document-search/qdrant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async def main() -> None:
7272
Run the example.
7373
"""
7474
embedder = LiteLLMEmbedder(
75-
model="text-embedding-3-small",
75+
model_name="text-embedding-3-small",
7676
)
7777
vector_store = QdrantVectorStore(
7878
client=AsyncQdrantClient(location=":memory:"),

packages/ragbits-core/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# CHANGELOG
22

33
## Unreleased
4+
5+
- Image embeddings in PgVectorStore (#495)
6+
- Add PgVectorStore to vector store integration tests (#495)
47
- Add new fusion strategies for the hybrid vector store: RRF and DBSF (#413)
58
- Move sources from ragbits-document-search to ragbits-core (#496)
69
- Add connection check to Azure get_blob_service (#502)

packages/ragbits-core/src/ragbits/core/vector_stores/pgvector.py

Lines changed: 56 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,15 @@ def _create_list_query(
178178
]
179179
return query, values
180180

181+
async def _check_table_exists(self) -> bool:
182+
check_table_existence = """
183+
SELECT EXISTS (
184+
SELECT FROM information_schema.tables
185+
WHERE table_name = $1
186+
); """
187+
async with self._client.acquire() as conn:
188+
return await conn.fetchval(check_table_existence, self._table_name)
189+
181190
async def create_table(self) -> None:
182191
"""
183192
Create a pgVector table with an HNSW index for given similarity.
@@ -188,43 +197,36 @@ async def create_table(self) -> None:
188197
vector_size=self._vector_size,
189198
hnsw_index_parameters=self._hnsw_params,
190199
):
191-
check_table_existence = """
192-
SELECT EXISTS (
193-
SELECT FROM information_schema.tables
194-
WHERE table_name = $1
195-
); """
196200
distance = DISTANCE_OPS[self._distance_method].function_name
197201
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
198202
# _table_name has been validated in the class constructor, and it is a valid table name.
199203
# _vector_size has been validated in the class constructor, and it is a valid vector size.
200204

201205
create_table_query = f"""
202206
CREATE TABLE {self._table_name}
203-
(id UUID, key TEXT, vector VECTOR({self._vector_size}), metadata JSONB);
207+
(id UUID, text TEXT, image_bytes BYTEA, vector VECTOR({self._vector_size}), metadata JSONB);
204208
"""
205209
# _hnsw_params has been validated in the class constructor, and it is valid dict[str,int].
206210
create_index_query = f"""
207211
CREATE INDEX {self._table_name + "_hnsw_idx"} ON {self._table_name}
208212
USING hnsw (vector {distance})
209213
WITH (m = {self._hnsw_params["m"]}, ef_construction = {self._hnsw_params["ef_construction"]});
210214
"""
211-
215+
if await self._check_table_exists():
216+
print(f"Table {self._table_name} already exist!")
217+
return
212218
async with self._client.acquire() as conn:
213219
await conn.execute(create_vector_extension)
214-
exists = await conn.fetchval(check_table_existence, self._table_name)
215220

216-
if not exists:
217-
try:
218-
async with conn.transaction():
219-
await conn.execute(create_table_query)
220-
await conn.execute(create_index_query)
221+
try:
222+
async with conn.transaction():
223+
await conn.execute(create_table_query)
224+
await conn.execute(create_index_query)
221225

222-
print("Table and index created!")
223-
except Exception as e:
224-
print(f"Failed to create table and index: {e}")
225-
raise
226-
else:
227-
print("Table already exists!")
226+
print("Table and index created!")
227+
except Exception as e:
228+
print(f"Failed to create table and index: {e}")
229+
raise
228230

229231
async def store(self, entries: list[VectorStoreEntry]) -> None:
230232
"""
@@ -237,8 +239,8 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
237239
return
238240
# _table_name has been validated in the class constructor, and it is a valid table name.
239241
insert_query = f"""
240-
INSERT INTO {self._table_name} (id, key, vector, metadata)
241-
VALUES ($1, $2, $3, $4)
242+
INSERT INTO {self._table_name} (id, text, image_bytes, vector, metadata)
243+
VALUES ($1, $2, $3, $4, $5)
242244
""" # noqa S608
243245
with trace(
244246
table_name=self._table_name,
@@ -248,30 +250,28 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
248250
embedding_type=self._embedding_type,
249251
):
250252
embeddings = await self._create_embeddings(entries)
251-
252-
try:
253-
async with self._client.acquire() as conn:
254-
for entry in entries:
255-
if entry.id not in embeddings:
256-
continue
257-
258-
await conn.execute(
259-
insert_query,
260-
str(entry.id),
261-
entry.text,
262-
str(embeddings[entry.id]),
263-
json.dumps(entry.metadata, default=pydantic_encoder),
264-
)
265-
except asyncpg.exceptions.UndefinedTableError:
253+
exists = await self._check_table_exists()
254+
if not exists:
266255
print(f"Table {self._table_name} does not exist. Creating the table.")
267256
try:
268257
await self.create_table()
269258
except Exception as e:
270259
print(f"Failed to handle missing table: {e}")
271260
return
272261

273-
print("Table created successfully. Inserting entries...")
274-
await self.store(entries)
262+
async with self._client.acquire() as conn:
263+
for entry in entries:
264+
if entry.id not in embeddings:
265+
continue
266+
267+
await conn.execute(
268+
insert_query,
269+
str(entry.id),
270+
entry.text,
271+
entry.image_bytes,
272+
str(embeddings[entry.id]),
273+
json.dumps(entry.metadata, default=pydantic_encoder),
274+
)
275275

276276
async def remove(self, ids: list[UUID]) -> None:
277277
"""
@@ -296,34 +296,6 @@ async def remove(self, ids: list[UUID]) -> None:
296296
print(f"Table {self._table_name} does not exist.")
297297
return
298298

299-
async def _fetch_records(self, query: str, values: list[Any]) -> list[VectorStoreEntry]:
300-
"""
301-
Fetch records from the pgVector collection.
302-
303-
Args:
304-
query: sql query
305-
values: list of values to be used in the query.
306-
307-
Returns:
308-
list of VectorStoreEntry objects.
309-
"""
310-
try:
311-
async with self._client.acquire() as conn:
312-
results = await conn.fetch(query, *values)
313-
314-
return [
315-
VectorStoreEntry(
316-
id=record["id"],
317-
text=record["key"],
318-
metadata=json.loads(record["metadata"]),
319-
)
320-
for record in results
321-
]
322-
323-
except asyncpg.exceptions.UndefinedTableError:
324-
print(f"Table {self._table_name} does not exist.")
325-
return []
326-
327299
async def retrieve(
328300
self,
329301
text: str,
@@ -362,7 +334,8 @@ async def retrieve(
362334
VectorStoreResult(
363335
entry=VectorStoreEntry(
364336
id=record["id"],
365-
text=record["key"],
337+
text=record["text"],
338+
image_bytes=record["image_bytes"],
366339
metadata=json.loads(record["metadata"]),
367340
),
368341
vector=json.loads(record["vector"]),
@@ -393,5 +366,20 @@ async def list(
393366
"""
394367
with trace(table=self._table_name, query=where, limit=limit, offset=offset) as outputs:
395368
list_query, values = self._create_list_query(where, limit, offset)
396-
outputs.listed_entries = await self._fetch_records(list_query, values)
369+
try:
370+
async with self._client.acquire() as conn:
371+
results = await conn.fetch(list_query, *values)
372+
outputs.listed_entries = [
373+
VectorStoreEntry(
374+
id=record["id"],
375+
text=record["text"],
376+
image_bytes=record["image_bytes"],
377+
metadata=json.loads(record["metadata"]),
378+
)
379+
for record in results
380+
]
381+
382+
except asyncpg.exceptions.UndefinedTableError:
383+
print(f"Table {self._table_name} does not exist.")
384+
outputs.listed_entries = []
397385
return outputs.listed_entries

packages/ragbits-core/tests/integration/vector_stores/test_vector_store.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from typing import cast
44
from uuid import UUID
55

6+
import asyncpg
67
import pytest
78
from chromadb import EphemeralClient
9+
from psycopg import Connection
810
from qdrant_client import AsyncQdrantClient
911

1012
from ragbits.core.embeddings.noop import NoopEmbedder
@@ -18,6 +20,7 @@
1820
)
1921
from ragbits.core.vector_stores.chroma import ChromaVectorStore
2022
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
23+
from ragbits.core.vector_stores.pgvector import PgVectorStore
2124
from ragbits.core.vector_stores.qdrant import QdrantVectorStore
2225
from ragbits.document_search import DocumentSearch
2326
from ragbits.document_search.documents.document import DocumentMeta
@@ -34,22 +37,31 @@
3437
IMAGES_PATH = Path(__file__).parent.parent.parent / "assets" / "img"
3538

3639

37-
# TODO: Add PgVectorStore
40+
@pytest.fixture
41+
async def pgvector_test_db(postgresql: Connection) -> asyncpg.pool:
42+
dsn = f"postgresql://{postgresql.info.user}:{postgresql.info.password}@{postgresql.info.host}:{postgresql.info.port}/{postgresql.info.dbname}"
43+
async with asyncpg.create_pool(dsn) as pool:
44+
yield pool
45+
46+
3847
@pytest.fixture(
3948
name="vector_store_cls",
4049
params=[
41-
lambda: partial(InMemoryVectorStore),
42-
lambda: partial(ChromaVectorStore, client=EphemeralClient(), index_name="test_index_name"),
43-
lambda: partial(QdrantVectorStore, client=AsyncQdrantClient(":memory:"), index_name="test_index_name"),
50+
lambda _: partial(InMemoryVectorStore),
51+
lambda _: partial(ChromaVectorStore, client=EphemeralClient(), index_name="test_index_name"),
52+
lambda _: partial(QdrantVectorStore, client=AsyncQdrantClient(":memory:"), index_name="test_index_name"),
53+
lambda pg_pool: partial(PgVectorStore, client=pg_pool, table_name="test_index_name", vector_size=3),
4454
],
45-
ids=["InMemoryVectorStore", "ChromaVectorStore", "QdrantVectorStore"],
55+
ids=["InMemoryVectorStore", "ChromaVectorStore", "QdrantVectorStore", "PgVectorStore"],
4656
)
47-
def vector_store_cls_fixture(request: pytest.FixtureRequest) -> type[VectorStoreWithExternalEmbedder]:
57+
def vector_store_cls_fixture(
58+
request: pytest.FixtureRequest, pgvector_test_db: asyncpg.pool
59+
) -> type[VectorStoreWithExternalEmbedder]:
4860
"""
4961
Returns vector store classes with different backends, with backend-specific parameters already set,
5062
but parameters common to VectorStoreWithExternalEmbedder left to be set.
5163
"""
52-
return request.param()
64+
return request.param(pgvector_test_db)
5365

5466

5567
@pytest.fixture(name="vector_store", params=[EmbeddingType.TEXT, EmbeddingType.IMAGE], ids=["Text", "Image"])

0 commit comments

Comments
 (0)