Skip to content

Commit 04aa5bc

Browse files
committed
Added test for arrays with SQLAlchemy async - #101
1 parent d23844e commit 04aa5bc

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,25 @@ async def test_async(self):
483483
assert avg.first() == '[2.5,3.5,4.5]'
484484

485485
await engine.dispose()
486+
487+
@pytest.mark.asyncio
488+
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
489+
async def test_async_vector_array(self):
490+
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
491+
async_session = async_sessionmaker(engine, expire_on_commit=False)
492+
493+
@event.listens_for(engine.sync_engine, "connect")
494+
def connect(dbapi_connection, connection_record):
495+
from pgvector.psycopg import register_vector_async
496+
dbapi_connection.run_async(register_vector_async)
497+
498+
async with async_session() as session:
499+
async with session.begin():
500+
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
501+
502+
# this fails if the driver does not cast arrays
503+
item = await session.get(Item, 1)
504+
assert item.embeddings[0].tolist() == [1, 2, 3]
505+
assert item.embeddings[1].tolist() == [4, 5, 6]
506+
507+
await engine.dispose()

0 commit comments

Comments
 (0)