Skip to content

Commit 0a76066

Browse files
committed
Use connection from session in example and tests
1 parent 1c7e6a5 commit 0a76066

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ And register the types with the underlying driver
285285
```python
286286
from pgvector.psycopg2 import register_vector
287287

288-
with engine.connect() as connection:
288+
with session.connection() as connection:
289289
register_vector(connection.connection.dbapi_connection, globally=True, arrays=True)
290290
```
291291

tests/test_sqlalchemy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,12 @@ def test_vector_array(self):
439439
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
440440
session.commit()
441441

442-
with engine.connect() as connection:
442+
with session.connection() as connection:
443443
from pgvector.psycopg2 import register_vector
444444
register_vector(connection.connection.dbapi_connection, globally=False, arrays=True)
445445

446446
# this fails if the driver does not cast arrays
447-
item = Session(bind=connection).get(Item, 1)
447+
item = session.get(Item, 1)
448448
assert item.embeddings[0].tolist() == [1, 2, 3]
449449
assert item.embeddings[1].tolist() == [4, 5, 6]
450450

@@ -453,12 +453,12 @@ def test_halfvec_array(self):
453453
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
454454
session.commit()
455455

456-
with engine.connect() as connection:
456+
with session.connection() as connection:
457457
from pgvector.psycopg2 import register_vector
458458
register_vector(connection.connection.dbapi_connection, globally=False, arrays=True)
459459

460460
# this fails if the driver does not cast arrays
461-
item = Session(bind=connection).get(Item, 1)
461+
item = session.get(Item, 1)
462462
assert item.half_embeddings[0].to_list() == [1, 2, 3]
463463
assert item.half_embeddings[1].to_list() == [4, 5, 6]
464464

0 commit comments

Comments
 (0)