Skip to content

Commit 88873e5

Browse files
committed
Improved tests for async SQLAlchemy engines [skip ci]
1 parent c792451 commit 88873e5

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

tests/test_sqlalchemy.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def psycopg2_connect(dbapi_connection, connection_record):
3131
engines = [psycopg2_engine, pg8000_engine, psycopg2_type_engine]
3232
array_engines = [psycopg2_type_engine]
3333
async_engines = []
34+
async_array_engines = []
3435

3536
if sqlalchemy_version > 1:
3637
psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
@@ -46,11 +47,32 @@ def psycopg_connect(dbapi_connection, connection_record):
4647
engines.append(psycopg_type_engine)
4748
array_engines.append(psycopg_type_engine)
4849

50+
psycopg_async_type_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
51+
52+
@event.listens_for(psycopg_async_type_engine.sync_engine, "connect")
53+
def connect(dbapi_connection, connection_record):
54+
from pgvector.psycopg import register_vector_async
55+
dbapi_connection.run_async(register_vector_async)
56+
57+
async_engines.append(psycopg_async_type_engine)
58+
async_array_engines.append(psycopg_async_type_engine)
59+
4960
psycopg_async_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
5061
async_engines.append(psycopg_async_engine)
5162

5263
asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
5364
async_engines.append(asyncpg_engine)
65+
async_array_engines.append(asyncpg_engine)
66+
67+
asyncpg_type_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
68+
69+
@event.listens_for(asyncpg_type_engine.sync_engine, "connect")
70+
def connect(dbapi_connection, connection_record):
71+
from pgvector.asyncpg import register_vector
72+
dbapi_connection.run_async(register_vector)
73+
74+
# TODO do not throw error when types are registered
75+
# async_array_engines.append(asyncpg_type_engine)
5476

5577
setup_engine = engines[0]
5678
with Session(setup_engine) as session:
@@ -599,6 +621,10 @@ async def test_sparsevec(self, engine):
599621

600622
@pytest.mark.asyncio
601623
async def test_avg(self, engine):
624+
# TODO do not throw error when types are registered
625+
if engine == psycopg_async_type_engine:
626+
return
627+
602628
async_session = async_sessionmaker(engine, expire_on_commit=False)
603629

604630
async with async_session() as session:
@@ -611,43 +637,15 @@ async def test_avg(self, engine):
611637
await engine.dispose()
612638

613639

614-
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
640+
@pytest.mark.parametrize('engine', async_array_engines)
615641
class TestSqlalchemyAsyncArray:
616642
def setup_method(self):
617643
delete_items()
618644

619645
@pytest.mark.asyncio
620-
async def test_psycopg_vector_array(self):
621-
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
646+
async def test_vector_array(self, engine):
622647
async_session = async_sessionmaker(engine, expire_on_commit=False)
623648

624-
@event.listens_for(engine.sync_engine, "connect")
625-
def connect(dbapi_connection, connection_record):
626-
from pgvector.psycopg import register_vector_async
627-
dbapi_connection.run_async(register_vector_async)
628-
629-
async with async_session() as session:
630-
async with session.begin():
631-
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
632-
633-
# this fails if the driver does not cast arrays
634-
item = await session.get(Item, 1)
635-
assert item.embeddings[0].tolist() == [1, 2, 3]
636-
assert item.embeddings[1].tolist() == [4, 5, 6]
637-
638-
await engine.dispose()
639-
640-
@pytest.mark.asyncio
641-
async def test_asyncpg_vector_array(self):
642-
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
643-
async_session = async_sessionmaker(engine, expire_on_commit=False)
644-
645-
# TODO do not throw error when types are registered
646-
# @event.listens_for(engine.sync_engine, "connect")
647-
# def connect(dbapi_connection, connection_record):
648-
# from pgvector.asyncpg import register_vector
649-
# dbapi_connection.run_async(register_vector)
650-
651649
async with async_session() as session:
652650
async with session.begin():
653651
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))

0 commit comments

Comments
 (0)