Skip to content

Commit 022dd06

Browse files
committed
Improved tests [skip ci]
1 parent 6b8857a commit 022dd06

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/test_sqlalchemy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncpg
22
import numpy as np
33
import os
4-
from pgvector import HalfVector, SparseVector
4+
from pgvector import HalfVector, SparseVector, Vector
55
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum
66
import pytest
77
from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
@@ -637,9 +637,14 @@ async def test_vector_array(self, engine):
637637

638638
async with async_session() as session:
639639
async with session.begin():
640-
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
640+
session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])]))
641641
item = await session.get(Item, 1)
642642
assert item.embeddings[0].tolist() == [1, 2, 3]
643643
assert item.embeddings[1].tolist() == [4, 5, 6]
644644

645+
session.add(Item(id=2, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
646+
item = await session.get(Item, 2)
647+
assert item.embeddings[0].tolist() == [1, 2, 3]
648+
assert item.embeddings[1].tolist() == [4, 5, 6]
649+
645650
await engine.dispose()

0 commit comments

Comments
 (0)