11import asyncpg
22import numpy as np
33import os
4- from pgvector import SparseVector
4+ from pgvector import HalfVector , SparseVector
55from pgvector .sqlalchemy import VECTOR , HALFVEC , BIT , SPARSEVEC , avg , sum
66import pytest
77from sqlalchemy import create_engine , event , insert , inspect , select , text , MetaData , Table , Column , Index , Integer , ARRAY
@@ -256,7 +256,7 @@ def test_halfvec(self, engine):
256256 session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
257257 session .commit ()
258258 item = session .get (Item , 1 )
259- assert item .half_embedding . to_list () == [1 , 2 , 3 ]
259+ assert item .half_embedding == HalfVector ( [1 , 2 , 3 ])
260260
261261 def test_halfvec_l2_distance (self , engine ):
262262 create_items ()
@@ -348,7 +348,7 @@ def test_sparsevec(self, engine):
348348 session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
349349 session .commit ()
350350 item = session .get (Item , 1 )
351- assert item .sparse_embedding . to_list () == [1 , 2 , 3 ]
351+ assert item .sparse_embedding == SparseVector ( [1 , 2 , 3 ])
352352
353353 def test_sparsevec_l2_distance (self , engine ):
354354 create_items ()
@@ -551,8 +551,8 @@ def test_halfvec_array(self, engine):
551551
552552 # this fails if the driver does not cast arrays
553553 item = session .get (Item , 1 )
554- assert item .half_embeddings [0 ]. to_list () == [1 , 2 , 3 ]
555- assert item .half_embeddings [1 ]. to_list () == [4 , 5 , 6 ]
554+ assert item .half_embeddings [0 ] == HalfVector ( [1 , 2 , 3 ])
555+ assert item .half_embeddings [1 ] == HalfVector ( [4 , 5 , 6 ])
556556
557557
558558@pytest .mark .parametrize ('engine' , async_engines )
@@ -582,7 +582,7 @@ async def test_halfvec(self, engine):
582582 embedding = [1 , 2 , 3 ]
583583 session .add (Item (id = 1 , half_embedding = embedding ))
584584 item = await session .get (Item , 1 )
585- assert item .half_embedding . to_list () == embedding
585+ assert item .half_embedding == HalfVector ( embedding )
586586
587587 await engine .dispose ()
588588
@@ -608,7 +608,7 @@ async def test_sparsevec(self, engine):
608608 embedding = [1 , 2 , 3 ]
609609 session .add (Item (id = 1 , sparse_embedding = embedding ))
610610 item = await session .get (Item , 1 )
611- assert item .sparse_embedding . to_list () == embedding
611+ assert item .sparse_embedding == SparseVector ( embedding )
612612
613613 await engine .dispose ()
614614
0 commit comments