Skip to content

Commit bb02ee2

Browse files
committed
Improved tests [skip ci]
1 parent 022dd06 commit bb02ee2

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/test_sqlmodel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector import SparseVector
2+
from pgvector import HalfVector, SparseVector
33
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum
44
import pytest
55
from sqlalchemy.exc import StatementError
@@ -107,7 +107,7 @@ def test_halfvec(self):
107107
session.add(Item(id=1, half_embedding=[1, 2, 3]))
108108
session.commit()
109109
item = session.get(Item, 1)
110-
assert item.half_embedding.to_list() == [1, 2, 3]
110+
assert item.half_embedding == HalfVector([1, 2, 3])
111111

112112
def test_halfvec_l2_distance(self):
113113
create_items()
@@ -157,7 +157,7 @@ def test_sparsevec(self):
157157
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
158158
session.commit()
159159
item = session.get(Item, 1)
160-
assert item.sparse_embedding.to_list() == [1, 2, 3]
160+
assert item.sparse_embedding == SparseVector([1, 2, 3])
161161

162162
def test_sparsevec_l2_distance(self):
163163
create_items()
@@ -220,7 +220,7 @@ def test_halfvec_avg(self):
220220
session.add(Item(half_embedding=[1, 2, 3]))
221221
session.add(Item(half_embedding=[4, 5, 6]))
222222
res = session.exec(select(avg(Item.half_embedding))).first()
223-
assert res.to_list() == [2.5, 3.5, 4.5]
223+
assert res == HalfVector([2.5, 3.5, 4.5])
224224

225225
def test_halfvec_sum(self):
226226
with Session(engine) as session:
@@ -229,7 +229,7 @@ def test_halfvec_sum(self):
229229
session.add(Item(half_embedding=[1, 2, 3]))
230230
session.add(Item(half_embedding=[4, 5, 6]))
231231
res = session.exec(select(sum(Item.half_embedding))).first()
232-
assert res.to_list() == [5, 7, 9]
232+
assert res == HalfVector([5, 7, 9])
233233

234234
def test_bad_dimensions(self):
235235
item = Item(embedding=[1, 2])

0 commit comments

Comments
 (0)