|
19 | 19 |
|
20 | 20 | psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test') |
21 | 21 | pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test') |
22 | | -psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test') |
| 22 | +psycopg2_type_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test') |
23 | 23 |
|
24 | 24 |
|
25 | | -@event.listens_for(psycopg2_array_engine, "connect") |
| 25 | +@event.listens_for(psycopg2_type_engine, "connect") |
26 | 26 | def psycopg2_connect(dbapi_connection, connection_record): |
27 | 27 | from pgvector.psycopg2 import register_vector |
28 | 28 | register_vector(dbapi_connection, globally=False, arrays=True) |
29 | 29 |
|
30 | 30 |
|
31 | | -engines = [psycopg2_engine, pg8000_engine] |
32 | | -array_engines = [psycopg2_array_engine] |
| 31 | +engines = [psycopg2_engine, pg8000_engine, psycopg2_type_engine] |
| 32 | +array_engines = [psycopg2_type_engine] |
33 | 33 | async_engines = [] |
34 | 34 |
|
35 | 35 | if sqlalchemy_version > 1: |
36 | 36 | psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test') |
37 | 37 | engines.append(psycopg_engine) |
38 | 38 |
|
| 39 | + psycopg_type_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test') |
| 40 | + |
| 41 | + @event.listens_for(psycopg_type_engine, "connect") |
| 42 | + def psycopg_connect(dbapi_connection, connection_record): |
| 43 | + from pgvector.psycopg import register_vector |
| 44 | + register_vector(dbapi_connection) |
| 45 | + |
| 46 | + engines.append(psycopg_type_engine) |
| 47 | + array_engines.append(psycopg_type_engine) |
| 48 | + |
39 | 49 | psycopg_async_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test') |
40 | 50 | async_engines.append(psycopg_async_engine) |
41 | 51 |
|
42 | 52 | asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test') |
43 | 53 | async_engines.append(asyncpg_engine) |
44 | 54 |
|
45 | | - psycopg_array_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test') |
46 | | - array_engines.append(psycopg_array_engine) |
47 | | - |
48 | | - @event.listens_for(psycopg_array_engine, "connect") |
49 | | - def psycopg_connect(dbapi_connection, connection_record): |
50 | | - from pgvector.psycopg import register_vector |
51 | | - register_vector(dbapi_connection) |
52 | | - |
53 | 55 | setup_engine = engines[0] |
54 | 56 | with Session(setup_engine) as session: |
55 | 57 | session.execute(text('CREATE EXTENSION IF NOT EXISTS vector')) |
@@ -169,9 +171,10 @@ def test_orm(self, engine): |
169 | 171 | stmt = select(Item) |
170 | 172 | with Session(engine) as session: |
171 | 173 | items = [v[0] for v in session.execute(stmt).all()] |
172 | | - assert items[0].id in [1, 4, 7] |
173 | | - assert items[1].id in [2, 5, 8] |
174 | | - assert items[2].id in [3, 6, 9] |
| 174 | + # TODO improve |
| 175 | + assert items[0].id % 3 == 1 |
| 176 | + assert items[1].id % 3 == 2 |
| 177 | + assert items[2].id % 3 == 0 |
175 | 178 | assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) |
176 | 179 | assert items[0].embedding.dtype == np.float32 |
177 | 180 | assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) |
|
0 commit comments