Skip to content

Commit 2883156

Browse files
committed
Improved tests for async SQLAlchemy engines [skip ci]
1 parent 651df08 commit 2883156

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tests/test_sqlalchemy.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,18 @@
1919
psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
2020
pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test')
2121
engines = [psycopg2_engine, pg8000_engine]
22+
async_engines = []
2223

2324
if sqlalchemy_version > 1:
2425
psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
2526
engines.append(psycopg_engine)
2627

28+
psycopg_async_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
29+
async_engines.append(psycopg_async_engine)
30+
31+
asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
32+
async_engines.append(asyncpg_engine)
33+
2734
setup_engine = engines[0]
2835
with Session(setup_engine) as session:
2936
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
@@ -531,14 +538,13 @@ def test_halfvec_array(self, engine):
531538
assert item.half_embeddings[1].to_list() == [4, 5, 6]
532539

533540

541+
@pytest.mark.parametrize('engine', async_engines)
534542
class TestSqlalchemyAsync:
535543
def setup_method(self):
536544
delete_items()
537545

538546
@pytest.mark.asyncio
539-
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
540-
async def test_psycopg_async_avg(self):
541-
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
547+
async def test_psycopg_async_avg(self, engine):
542548
async_session = async_sessionmaker(engine, expire_on_commit=False)
543549

544550
async with async_session() as session:
@@ -550,6 +556,11 @@ async def test_psycopg_async_avg(self):
550556

551557
await engine.dispose()
552558

559+
560+
class TestSqlalchemyAsync2:
561+
def setup_method(self):
562+
delete_items()
563+
553564
@pytest.mark.asyncio
554565
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
555566
async def test_psycopg_async_vector_array(self):

0 commit comments

Comments
 (0)