@@ -544,122 +544,96 @@ def setup_method(self):
544544 delete_items ()
545545
546546 @pytest .mark .asyncio
547- async def test_psycopg_async_avg (self , engine ):
547+ async def test_vector (self , engine ):
548548 async_session = async_sessionmaker (engine , expire_on_commit = False )
549549
550550 async with async_session () as session :
551551 async with session .begin ():
552- session . add ( Item ( embedding = [1 , 2 , 3 ]) )
553- session .add (Item (embedding = [ 4 , 5 , 6 ] ))
554- avg = await session .scalars ( select ( func . avg ( Item . embedding )) )
555- assert avg . first () == '[2.5,3.5,4.5]'
552+ embedding = np . array ( [1 , 2 , 3 ])
553+ session .add (Item (id = 1 , embedding = embedding ))
554+ item = await session .get ( Item , 1 )
555+ assert np . array_equal ( item . embedding , embedding )
556556
557557 await engine .dispose ()
558558
559-
560- class TestSqlalchemyAsync2 :
561- def setup_method (self ):
562- delete_items ()
563-
564559 @pytest .mark .asyncio
565- @pytest .mark .skipif (sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
566- async def test_psycopg_async_vector_array (self ):
567- engine = create_async_engine ('postgresql+psycopg://localhost/pgvector_python_test' )
560+ async def test_halfvec (self , engine ):
568561 async_session = async_sessionmaker (engine , expire_on_commit = False )
569562
570- @event .listens_for (engine .sync_engine , "connect" )
571- def connect (dbapi_connection , connection_record ):
572- from pgvector .psycopg import register_vector_async
573- dbapi_connection .run_async (register_vector_async )
574-
575563 async with async_session () as session :
576564 async with session .begin ():
577- session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
578-
579- # this fails if the driver does not cast arrays
565+ embedding = [1 , 2 , 3 ]
566+ session .add (Item (id = 1 , half_embedding = embedding ))
580567 item = await session .get (Item , 1 )
581- assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
582- assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
568+ assert item .half_embedding .to_list () == embedding
583569
584570 await engine .dispose ()
585571
586572 @pytest .mark .asyncio
587- @pytest .mark .skipif (sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
588- async def test_asyncpg_vector (self ):
589- engine = create_async_engine ('postgresql+asyncpg://localhost/pgvector_python_test' )
590- async_session = async_sessionmaker (engine , expire_on_commit = False )
573+ async def test_bit (self , engine ):
574+ import asyncpg
591575
592- # TODO do not throw error when types are registered
593- # @event.listens_for(engine.sync_engine, "connect")
594- # def connect(dbapi_connection, connection_record):
595- # from pgvector.asyncpg import register_vector
596- # dbapi_connection.run_async(register_vector)
576+ async_session = async_sessionmaker (engine , expire_on_commit = False )
597577
598578 async with async_session () as session :
599579 async with session .begin ():
600- embedding = np . array ([ 1 , 2 , 3 ])
601- session .add (Item (id = 1 , embedding = embedding ))
580+ embedding = asyncpg . BitString ( '101' ) if engine == asyncpg_engine else '101'
581+ session .add (Item (id = 1 , binary_embedding = embedding ))
602582 item = await session .get (Item , 1 )
603- assert np . array_equal ( item .embedding , embedding )
583+ assert item .binary_embedding == embedding
604584
605585 await engine .dispose ()
606586
607587 @pytest .mark .asyncio
608- @pytest .mark .skipif (sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
609- async def test_asyncpg_halfvec (self ):
610- engine = create_async_engine ('postgresql+asyncpg://localhost/pgvector_python_test' )
588+ async def test_sparsevec (self , engine ):
611589 async_session = async_sessionmaker (engine , expire_on_commit = False )
612590
613- # TODO do not throw error when types are registered
614- # @event.listens_for(engine.sync_engine, "connect")
615- # def connect(dbapi_connection, connection_record):
616- # from pgvector.asyncpg import register_vector
617- # dbapi_connection.run_async(register_vector)
618-
619591 async with async_session () as session :
620592 async with session .begin ():
621593 embedding = [1 , 2 , 3 ]
622- session .add (Item (id = 1 , half_embedding = embedding ))
594+ session .add (Item (id = 1 , sparse_embedding = embedding ))
623595 item = await session .get (Item , 1 )
624- assert item .half_embedding .to_list () == embedding
596+ assert item .sparse_embedding .to_list () == embedding
625597
626598 await engine .dispose ()
627599
628600 @pytest .mark .asyncio
629- @pytest .mark .skipif (sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
630- async def test_asyncpg_bit (self ):
631- import asyncpg
632-
633- engine = create_async_engine ('postgresql+asyncpg://localhost/pgvector_python_test' )
601+ async def test_avg (self , engine ):
634602 async_session = async_sessionmaker (engine , expire_on_commit = False )
635603
636604 async with async_session () as session :
637605 async with session .begin ():
638- embedding = asyncpg . BitString ( '101' )
639- session .add (Item (id = 1 , binary_embedding = embedding ))
640- item = await session .get ( Item , 1 )
641- assert item . binary_embedding == embedding
606+ session . add ( Item ( embedding = [ 1 , 2 , 3 ]) )
607+ session .add (Item (embedding = [ 4 , 5 , 6 ] ))
608+ avg = await session .scalars ( select ( func . avg ( Item . embedding )) )
609+ assert avg . first () == '[2.5,3.5,4.5]'
642610
643611 await engine .dispose ()
644612
613+
614+ class TestSqlalchemyAsync2 :
615+ def setup_method (self ):
616+ delete_items ()
617+
645618 @pytest .mark .asyncio
646619 @pytest .mark .skipif (sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
647- async def test_asyncpg_sparsevec (self ):
648- engine = create_async_engine ('postgresql+asyncpg ://localhost/pgvector_python_test' )
620+ async def test_psycopg_async_vector_array (self ):
621+ engine = create_async_engine ('postgresql+psycopg ://localhost/pgvector_python_test' )
649622 async_session = async_sessionmaker (engine , expire_on_commit = False )
650623
651- # TODO do not throw error when types are registered
652- # @event.listens_for(engine.sync_engine, "connect")
653- # def connect(dbapi_connection, connection_record):
654- # from pgvector.asyncpg import register_vector
655- # dbapi_connection.run_async(register_vector)
624+ @event .listens_for (engine .sync_engine , "connect" )
625+ def connect (dbapi_connection , connection_record ):
626+ from pgvector .psycopg import register_vector_async
627+ dbapi_connection .run_async (register_vector_async )
656628
657629 async with async_session () as session :
658630 async with session .begin ():
659- embedding = [1 , 2 , 3 ]
660- session .add (Item (id = 1 , sparse_embedding = embedding ))
631+ session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
632+
633+ # this fails if the driver does not cast arrays
661634 item = await session .get (Item , 1 )
662- assert item .sparse_embedding .to_list () == embedding
635+ assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
636+ assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
663637
664638 await engine .dispose ()
665639
0 commit comments