Skip to content

Commit a1d8997

Browse files
committed
Improved test code [skip ci]
1 parent 224c18a commit a1d8997

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

tests/test_sqlalchemy.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,17 @@
1919

2020
psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
2121
pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test')
22+
psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
23+
24+
25+
@event.listens_for(psycopg2_array_engine, "connect")
26+
def psycopg2_connect(dbapi_connection, connection_record):
27+
from pgvector.psycopg2 import register_vector
28+
register_vector(dbapi_connection, globally=False, arrays=True)
29+
30+
2231
engines = [psycopg2_engine, pg8000_engine]
32+
array_engines = [psycopg2_array_engine]
2333
async_engines = []
2434

2535
if sqlalchemy_version > 1:
@@ -32,22 +42,6 @@
3242
asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
3343
async_engines.append(asyncpg_engine)
3444

35-
setup_engine = engines[0]
36-
with Session(setup_engine) as session:
37-
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
38-
session.commit()
39-
40-
psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
41-
array_engines = [psycopg2_array_engine]
42-
43-
44-
@event.listens_for(psycopg2_array_engine, "connect")
45-
def psycopg2_connect(dbapi_connection, connection_record):
46-
from pgvector.psycopg2 import register_vector
47-
register_vector(dbapi_connection, globally=False, arrays=True)
48-
49-
50-
if sqlalchemy_version > 1:
5145
psycopg_array_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
5246
array_engines.append(psycopg_array_engine)
5347

@@ -56,6 +50,10 @@ def psycopg_connect(dbapi_connection, connection_record):
5650
from pgvector.psycopg import register_vector
5751
register_vector(dbapi_connection)
5852

53+
setup_engine = engines[0]
54+
with Session(setup_engine) as session:
55+
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
56+
session.commit()
5957

6058
Base = declarative_base()
6159

0 commit comments

Comments
 (0)