Skip to content

Commit 86331f0

Browse files
committed
Improved tests for async SQLAlchemy engines [skip ci]
1 parent 2883156 commit 86331f0

File tree

1 file changed

+39
-65
lines changed

1 file changed

+39
-65
lines changed

tests/test_sqlalchemy.py

Lines changed: 39 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -544,122 +544,96 @@ def setup_method(self):
544544
delete_items()
545545

546546
@pytest.mark.asyncio
547-
async def test_psycopg_async_avg(self, engine):
547+
async def test_vector(self, engine):
548548
async_session = async_sessionmaker(engine, expire_on_commit=False)
549549

550550
async with async_session() as session:
551551
async with session.begin():
552-
session.add(Item(embedding=[1, 2, 3]))
553-
session.add(Item(embedding=[4, 5, 6]))
554-
avg = await session.scalars(select(func.avg(Item.embedding)))
555-
assert avg.first() == '[2.5,3.5,4.5]'
552+
embedding = np.array([1, 2, 3])
553+
session.add(Item(id=1, embedding=embedding))
554+
item = await session.get(Item, 1)
555+
assert np.array_equal(item.embedding, embedding)
556556

557557
await engine.dispose()
558558

559-
560-
class TestSqlalchemyAsync2:
561-
def setup_method(self):
562-
delete_items()
563-
564559
@pytest.mark.asyncio
565-
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
566-
async def test_psycopg_async_vector_array(self):
567-
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
560+
async def test_halfvec(self, engine):
568561
async_session = async_sessionmaker(engine, expire_on_commit=False)
569562

570-
@event.listens_for(engine.sync_engine, "connect")
571-
def connect(dbapi_connection, connection_record):
572-
from pgvector.psycopg import register_vector_async
573-
dbapi_connection.run_async(register_vector_async)
574-
575563
async with async_session() as session:
576564
async with session.begin():
577-
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
578-
579-
# this fails if the driver does not cast arrays
565+
embedding = [1, 2, 3]
566+
session.add(Item(id=1, half_embedding=embedding))
580567
item = await session.get(Item, 1)
581-
assert item.embeddings[0].tolist() == [1, 2, 3]
582-
assert item.embeddings[1].tolist() == [4, 5, 6]
568+
assert item.half_embedding.to_list() == embedding
583569

584570
await engine.dispose()
585571

586572
@pytest.mark.asyncio
587-
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
588-
async def test_asyncpg_vector(self):
589-
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
590-
async_session = async_sessionmaker(engine, expire_on_commit=False)
573+
async def test_bit(self, engine):
574+
import asyncpg
591575

592-
# TODO do not throw error when types are registered
593-
# @event.listens_for(engine.sync_engine, "connect")
594-
# def connect(dbapi_connection, connection_record):
595-
# from pgvector.asyncpg import register_vector
596-
# dbapi_connection.run_async(register_vector)
576+
async_session = async_sessionmaker(engine, expire_on_commit=False)
597577

598578
async with async_session() as session:
599579
async with session.begin():
600-
embedding = np.array([1, 2, 3])
601-
session.add(Item(id=1, embedding=embedding))
580+
embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101'
581+
session.add(Item(id=1, binary_embedding=embedding))
602582
item = await session.get(Item, 1)
603-
assert np.array_equal(item.embedding, embedding)
583+
assert item.binary_embedding == embedding
604584

605585
await engine.dispose()
606586

607587
@pytest.mark.asyncio
608-
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
609-
async def test_asyncpg_halfvec(self):
610-
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
588+
async def test_sparsevec(self, engine):
611589
async_session = async_sessionmaker(engine, expire_on_commit=False)
612590

613-
# TODO do not throw error when types are registered
614-
# @event.listens_for(engine.sync_engine, "connect")
615-
# def connect(dbapi_connection, connection_record):
616-
# from pgvector.asyncpg import register_vector
617-
# dbapi_connection.run_async(register_vector)
618-
619591
async with async_session() as session:
620592
async with session.begin():
621593
embedding = [1, 2, 3]
622-
session.add(Item(id=1, half_embedding=embedding))
594+
session.add(Item(id=1, sparse_embedding=embedding))
623595
item = await session.get(Item, 1)
624-
assert item.half_embedding.to_list() == embedding
596+
assert item.sparse_embedding.to_list() == embedding
625597

626598
await engine.dispose()
627599

628600
@pytest.mark.asyncio
629-
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
630-
async def test_asyncpg_bit(self):
631-
import asyncpg
632-
633-
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
601+
async def test_avg(self, engine):
634602
async_session = async_sessionmaker(engine, expire_on_commit=False)
635603

636604
async with async_session() as session:
637605
async with session.begin():
638-
embedding = asyncpg.BitString('101')
639-
session.add(Item(id=1, binary_embedding=embedding))
640-
item = await session.get(Item, 1)
641-
assert item.binary_embedding == embedding
606+
session.add(Item(embedding=[1, 2, 3]))
607+
session.add(Item(embedding=[4, 5, 6]))
608+
avg = await session.scalars(select(func.avg(Item.embedding)))
609+
assert avg.first() == '[2.5,3.5,4.5]'
642610

643611
await engine.dispose()
644612

613+
614+
class TestSqlalchemyAsync2:
615+
def setup_method(self):
616+
delete_items()
617+
645618
@pytest.mark.asyncio
646619
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
647-
async def test_asyncpg_sparsevec(self):
648-
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
620+
async def test_psycopg_async_vector_array(self):
621+
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
649622
async_session = async_sessionmaker(engine, expire_on_commit=False)
650623

651-
# TODO do not throw error when types are registered
652-
# @event.listens_for(engine.sync_engine, "connect")
653-
# def connect(dbapi_connection, connection_record):
654-
# from pgvector.asyncpg import register_vector
655-
# dbapi_connection.run_async(register_vector)
624+
@event.listens_for(engine.sync_engine, "connect")
625+
def connect(dbapi_connection, connection_record):
626+
from pgvector.psycopg import register_vector_async
627+
dbapi_connection.run_async(register_vector_async)
656628

657629
async with async_session() as session:
658630
async with session.begin():
659-
embedding = [1, 2, 3]
660-
session.add(Item(id=1, sparse_embedding=embedding))
631+
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
632+
633+
# this fails if the driver does not cast arrays
661634
item = await session.get(Item, 1)
662-
assert item.sparse_embedding.to_list() == embedding
635+
assert item.embeddings[0].tolist() == [1, 2, 3]
636+
assert item.embeddings[1].tolist() == [4, 5, 6]
663637

664638
await engine.dispose()
665639

0 commit comments

Comments
 (0)