Skip to content

Commit 5c35a53

Browse files
committed
Added test for binary quantization with SQLAlchemy - #98 [skip ci]
1 parent ca637bf commit 5c35a53

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ class Item(Base):
6565
)
6666
half_precision_index.create(engine)
6767

68+
binary_quantize_index = Index(
69+
'sqlalchemy_orm_binary_quantize_index',
70+
func.cast(func.binary_quantize(Item.embedding), BIT(3)).label('embedding'),
71+
postgresql_using='hnsw',
72+
postgresql_with={'m': 16, 'ef_construction': 64},
73+
postgresql_ops={'embedding': 'bit_hamming_ops'}
74+
)
75+
binary_quantize_index.create(engine)
6876

6977
def create_items():
7078
session = Session(engine)
@@ -469,6 +477,18 @@ def test_half_precision(self):
469477
items = session.query(Item).order_by(func.cast(Item.embedding, HALFVEC(3)).l2_distance([1, 1, 1])).all()
470478
assert [v.id for v in items] == [1, 3, 2]
471479

480+
def test_binary_quantize(self):
481+
session = Session(engine)
482+
session.add(Item(id=1, embedding=[-1, -2, -3]))
483+
session.add(Item(id=2, embedding=[1, -2, 3]))
484+
session.add(Item(id=3, embedding=[1, 2, 3]))
485+
session.commit()
486+
487+
with Session(engine) as session:
488+
distance = func.cast(func.binary_quantize(Item.embedding), BIT(3)).hamming_distance(func.binary_quantize(func.cast([3, -1, 2], VECTOR(3))))
489+
items = session.query(Item).order_by(distance).all()
490+
assert [v.id for v in items] == [2, 3, 1]
491+
472492
@pytest.mark.asyncio
473493
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
474494
async def test_async(self):

0 commit comments

Comments
 (0)