Skip to content

Commit 664b8ee

Browse files
committed
Added test for halfvec arrays with SQLAlchemy - #101
1 parent ea32504 commit 664b8ee

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Item(Base):
3232
binary_embedding = mapped_column(BIT(3))
3333
sparse_embedding = mapped_column(SPARSEVEC(3))
3434
embeddings = mapped_column(ARRAY(VECTOR(3)))
35+
half_embeddings = mapped_column(ARRAY(HALFVEC(3)))
3536

3637

3738
Base.metadata.drop_all(engine)
@@ -447,6 +448,20 @@ def test_vector_array(self):
447448
assert item.embeddings[0].tolist() == [1, 2, 3]
448449
assert item.embeddings[1].tolist() == [4, 5, 6]
449450

451+
def test_halfvec_array(self):
452+
session = Session(engine)
453+
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
454+
session.commit()
455+
456+
with engine.connect() as connection:
457+
from pgvector.psycopg2 import register_vector
458+
register_vector(connection.connection.dbapi_connection, globally=False, arrays=True)
459+
460+
# this fails if the driver does not cast arrays
461+
item = Session(bind=connection).get(Item, 1)
462+
assert item.half_embeddings[0].to_list() == [1, 2, 3]
463+
assert item.half_embeddings[1].to_list() == [4, 5, 6]
464+
450465
def test_half_precision(self):
451466
create_items()
452467
with Session(engine) as session:

0 commit comments

Comments
 (0)