11import numpy as np
2- from pgvector import SparseVector
2+ from pgvector import HalfVector , SparseVector
33from pgvector .sqlalchemy import VECTOR , HALFVEC , BIT , SPARSEVEC , avg , sum
44import pytest
55from sqlalchemy .exc import StatementError
@@ -107,7 +107,7 @@ def test_halfvec(self):
107107 session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
108108 session .commit ()
109109 item = session .get (Item , 1 )
110- assert item .half_embedding . to_list () == [1 , 2 , 3 ]
110+ assert item .half_embedding == HalfVector ( [1 , 2 , 3 ])
111111
112112 def test_halfvec_l2_distance (self ):
113113 create_items ()
@@ -157,7 +157,7 @@ def test_sparsevec(self):
157157 session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
158158 session .commit ()
159159 item = session .get (Item , 1 )
160- assert item .sparse_embedding . to_list () == [1 , 2 , 3 ]
160+ assert item .sparse_embedding == SparseVector ( [1 , 2 , 3 ])
161161
162162 def test_sparsevec_l2_distance (self ):
163163 create_items ()
@@ -220,7 +220,7 @@ def test_halfvec_avg(self):
220220 session .add (Item (half_embedding = [1 , 2 , 3 ]))
221221 session .add (Item (half_embedding = [4 , 5 , 6 ]))
222222 res = session .exec (select (avg (Item .half_embedding ))).first ()
223- assert res . to_list () == [2.5 , 3.5 , 4.5 ]
223+ assert res == HalfVector ( [2.5 , 3.5 , 4.5 ])
224224
225225 def test_halfvec_sum (self ):
226226 with Session (engine ) as session :
@@ -229,7 +229,7 @@ def test_halfvec_sum(self):
229229 session .add (Item (half_embedding = [1 , 2 , 3 ]))
230230 session .add (Item (half_embedding = [4 , 5 , 6 ]))
231231 res = session .exec (select (sum (Item .half_embedding ))).first ()
232- assert res . to_list () == [5 , 7 , 9 ]
232+ assert res == HalfVector ( [5 , 7 , 9 ])
233233
234234 def test_bad_dimensions (self ):
235235 item = Item (embedding = [1 , 2 ])
0 commit comments