Skip to content

Commit 6b8857a

Browse files
committed
Improved tests [skip ci]
1 parent 8441b46 commit 6b8857a

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/test_sqlalchemy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncpg
22
import numpy as np
33
import os
4-
from pgvector import SparseVector
4+
from pgvector import HalfVector, SparseVector
55
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum
66
import pytest
77
from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
@@ -256,7 +256,7 @@ def test_halfvec(self, engine):
256256
session.add(Item(id=1, half_embedding=[1, 2, 3]))
257257
session.commit()
258258
item = session.get(Item, 1)
259-
assert item.half_embedding.to_list() == [1, 2, 3]
259+
assert item.half_embedding == HalfVector([1, 2, 3])
260260

261261
def test_halfvec_l2_distance(self, engine):
262262
create_items()
@@ -348,7 +348,7 @@ def test_sparsevec(self, engine):
348348
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
349349
session.commit()
350350
item = session.get(Item, 1)
351-
assert item.sparse_embedding.to_list() == [1, 2, 3]
351+
assert item.sparse_embedding == SparseVector([1, 2, 3])
352352

353353
def test_sparsevec_l2_distance(self, engine):
354354
create_items()
@@ -551,8 +551,8 @@ def test_halfvec_array(self, engine):
551551

552552
# this fails if the driver does not cast arrays
553553
item = session.get(Item, 1)
554-
assert item.half_embeddings[0].to_list() == [1, 2, 3]
555-
assert item.half_embeddings[1].to_list() == [4, 5, 6]
554+
assert item.half_embeddings[0] == HalfVector([1, 2, 3])
555+
assert item.half_embeddings[1] == HalfVector([4, 5, 6])
556556

557557

558558
@pytest.mark.parametrize('engine', async_engines)
@@ -582,7 +582,7 @@ async def test_halfvec(self, engine):
582582
embedding = [1, 2, 3]
583583
session.add(Item(id=1, half_embedding=embedding))
584584
item = await session.get(Item, 1)
585-
assert item.half_embedding.to_list() == embedding
585+
assert item.half_embedding == HalfVector(embedding)
586586

587587
await engine.dispose()
588588

@@ -608,7 +608,7 @@ async def test_sparsevec(self, engine):
608608
embedding = [1, 2, 3]
609609
session.add(Item(id=1, sparse_embedding=embedding))
610610
item = await session.get(Item, 1)
611-
assert item.sparse_embedding.to_list() == embedding
611+
assert item.sparse_embedding == SparseVector(embedding)
612612

613613
await engine.dispose()
614614

0 commit comments

Comments
 (0)