Skip to content

Commit 1037d7e

Browse files
committed
Added test for half-precision indexing with SQLAlchemy - #98
1 parent d24839b commit 1037d7e

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ class Item(Base):
4646
)
4747
index.create(engine)
4848

49+
half_precision_index = Index(
50+
'sqlalchemy_orm_half_precision_index',
51+
func.cast(Item.embedding, HALFVEC(3)).label('embedding'),
52+
postgresql_using='hnsw',
53+
postgresql_with={'m': 16, 'ef_construction': 64},
54+
postgresql_ops={'embedding': 'halfvec_l2_ops'}
55+
)
56+
half_precision_index.create(engine)
57+
4958

5059
def create_items():
5160
session = Session(engine)
@@ -438,6 +447,12 @@ def test_vector_array(self):
438447
assert item.embeddings[0].tolist() == [1, 2, 3]
439448
assert item.embeddings[1].tolist() == [4, 5, 6]
440449

450+
def test_half_precision(self):
451+
create_items()
452+
with Session(engine) as session:
453+
items = session.query(Item).order_by(func.cast(Item.embedding, HALFVEC(3)).l2_distance([1, 1, 1])).all()
454+
assert [v.id for v in items] == [1, 3, 2]
455+
441456
@pytest.mark.asyncio
442457
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
443458
async def test_async(self):

0 commit comments

Comments
 (0)