Skip to content

Commit 5e38160

Browse files
committed
Improved array tests
1 parent f08cec7 commit 5e38160

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

tests/test_sqlalchemy.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@
2727
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
2828
session.commit()
2929

30-
array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
30+
psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
31+
array_engines = [psycopg2_array_engine]
3132

3233

33-
@event.listens_for(array_engine, "connect")
34+
@event.listens_for(psycopg2_array_engine, "connect")
3435
def connect(dbapi_connection, connection_record):
3536
from pgvector.psycopg2 import register_vector
3637
register_vector(dbapi_connection, globally=False, arrays=True)
@@ -468,8 +469,31 @@ def test_automap(self, engine):
468469
item = session.query(AutoItem).first()
469470
assert item.embedding.tolist() == [1, 2, 3]
470471

472+
def test_half_precision(self, engine):
473+
create_items()
474+
with Session(engine) as session:
475+
items = session.query(Item).order_by(func.cast(Item.embedding, HALFVEC(3)).l2_distance([1, 1, 1])).all()
476+
assert [v.id for v in items] == [1, 3, 2]
477+
478+
def test_binary_quantize(self, engine):
479+
with Session(engine) as session:
480+
session.add(Item(id=1, embedding=[-1, -2, -3]))
481+
session.add(Item(id=2, embedding=[1, -2, 3]))
482+
session.add(Item(id=3, embedding=[1, 2, 3]))
483+
session.commit()
484+
485+
distance = func.cast(func.binary_quantize(Item.embedding), BIT(3)).hamming_distance(func.binary_quantize(func.cast([3, -1, 2], VECTOR(3))))
486+
items = session.query(Item).order_by(distance).all()
487+
assert [v.id for v in items] == [2, 3, 1]
488+
489+
490+
@pytest.mark.parametrize("engine", array_engines)
491+
class TestSqlalchemyArray:
492+
def setup_method(self):
493+
delete_items()
494+
471495
def test_vector_array(self, engine):
472-
with Session(array_engine) as session:
496+
with Session(engine) as session:
473497
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
474498
session.commit()
475499

@@ -479,7 +503,7 @@ def test_vector_array(self, engine):
479503
assert item.embeddings[1].tolist() == [4, 5, 6]
480504

481505
def test_halfvec_array(self, engine):
482-
with Session(array_engine) as session:
506+
with Session(engine) as session:
483507
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
484508
session.commit()
485509

@@ -488,23 +512,6 @@ def test_halfvec_array(self, engine):
488512
assert item.half_embeddings[0].to_list() == [1, 2, 3]
489513
assert item.half_embeddings[1].to_list() == [4, 5, 6]
490514

491-
def test_half_precision(self, engine):
492-
create_items()
493-
with Session(engine) as session:
494-
items = session.query(Item).order_by(func.cast(Item.embedding, HALFVEC(3)).l2_distance([1, 1, 1])).all()
495-
assert [v.id for v in items] == [1, 3, 2]
496-
497-
def test_binary_quantize(self, engine):
498-
with Session(engine) as session:
499-
session.add(Item(id=1, embedding=[-1, -2, -3]))
500-
session.add(Item(id=2, embedding=[1, -2, 3]))
501-
session.add(Item(id=3, embedding=[1, 2, 3]))
502-
session.commit()
503-
504-
distance = func.cast(func.binary_quantize(Item.embedding), BIT(3)).hamming_distance(func.binary_quantize(func.cast([3, -1, 2], VECTOR(3))))
505-
items = session.query(Item).order_by(distance).all()
506-
assert [v.id for v in items] == [2, 3, 1]
507-
508515

509516
class TestSqlalchemyAsync:
510517
def setup_method(self):

0 commit comments

Comments
 (0)