@@ -31,6 +31,7 @@ def psycopg2_connect(dbapi_connection, connection_record):
3131engines = [psycopg2_engine , pg8000_engine , psycopg2_type_engine ]
3232array_engines = [psycopg2_type_engine ]
3333async_engines = []
34+ async_array_engines = []
3435
3536if sqlalchemy_version > 1 :
3637 psycopg_engine = create_engine ('postgresql+psycopg://localhost/pgvector_python_test' )
@@ -46,11 +47,32 @@ def psycopg_connect(dbapi_connection, connection_record):
4647 engines .append (psycopg_type_engine )
4748 array_engines .append (psycopg_type_engine )
4849
50+ psycopg_async_type_engine = create_async_engine ('postgresql+psycopg://localhost/pgvector_python_test' )
51+
52+ @event .listens_for (psycopg_async_type_engine .sync_engine , "connect" )
53+ def connect (dbapi_connection , connection_record ):
54+ from pgvector .psycopg import register_vector_async
55+ dbapi_connection .run_async (register_vector_async )
56+
57+ async_engines .append (psycopg_async_type_engine )
58+ async_array_engines .append (psycopg_async_type_engine )
59+
4960 psycopg_async_engine = create_async_engine ('postgresql+psycopg://localhost/pgvector_python_test' )
5061 async_engines .append (psycopg_async_engine )
5162
5263 asyncpg_engine = create_async_engine ('postgresql+asyncpg://localhost/pgvector_python_test' )
5364 async_engines .append (asyncpg_engine )
65+ async_array_engines .append (asyncpg_engine )
66+
67+ asyncpg_type_engine = create_async_engine ('postgresql+asyncpg://localhost/pgvector_python_test' )
68+
69+ @event .listens_for (asyncpg_type_engine .sync_engine , "connect" )
70+ def connect (dbapi_connection , connection_record ):
71+ from pgvector .asyncpg import register_vector
72+ dbapi_connection .run_async (register_vector )
73+
74+ # TODO do not throw error when types are registered
75+ # async_array_engines.append(asyncpg_type_engine)
5476
5577setup_engine = engines [0 ]
5678with Session (setup_engine ) as session :
@@ -599,6 +621,10 @@ async def test_sparsevec(self, engine):
599621
600622 @pytest .mark .asyncio
601623 async def test_avg (self , engine ):
624+ # TODO do not throw error when types are registered
625+ if engine == psycopg_async_type_engine :
626+ return
627+
602628 async_session = async_sessionmaker (engine , expire_on_commit = False )
603629
604630 async with async_session () as session :
@@ -611,43 +637,15 @@ async def test_avg(self, engine):
611637 await engine .dispose ()
612638
613639
614- @pytest .mark .skipif ( sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
640+ @pytest .mark .parametrize ( 'engine' , async_array_engines )
615641class TestSqlalchemyAsyncArray :
616642 def setup_method (self ):
617643 delete_items ()
618644
619645 @pytest .mark .asyncio
620- async def test_psycopg_vector_array (self ):
621- engine = create_async_engine ('postgresql+psycopg://localhost/pgvector_python_test' )
646+ async def test_vector_array (self , engine ):
622647 async_session = async_sessionmaker (engine , expire_on_commit = False )
623648
624- @event .listens_for (engine .sync_engine , "connect" )
625- def connect (dbapi_connection , connection_record ):
626- from pgvector .psycopg import register_vector_async
627- dbapi_connection .run_async (register_vector_async )
628-
629- async with async_session () as session :
630- async with session .begin ():
631- session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
632-
633- # this fails if the driver does not cast arrays
634- item = await session .get (Item , 1 )
635- assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
636- assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
637-
638- await engine .dispose ()
639-
640- @pytest .mark .asyncio
641- async def test_asyncpg_vector_array (self ):
642- engine = create_async_engine ('postgresql+asyncpg://localhost/pgvector_python_test' )
643- async_session = async_sessionmaker (engine , expire_on_commit = False )
644-
645- # TODO do not throw error when types are registered
646- # @event.listens_for(engine.sync_engine, "connect")
647- # def connect(dbapi_connection, connection_record):
648- # from pgvector.asyncpg import register_vector
649- # dbapi_connection.run_async(register_vector)
650-
651649 async with async_session () as session :
652650 async with session .begin ():
653651 session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
0 commit comments