|
1 | 1 | import asyncpg |
2 | 2 | import numpy as np |
3 | 3 | import os |
4 | | -from pgvector import HalfVector, SparseVector |
| 4 | +from pgvector import HalfVector, SparseVector, Vector |
5 | 5 | from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum |
6 | 6 | import pytest |
7 | 7 | from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY |
@@ -637,9 +637,14 @@ async def test_vector_array(self, engine): |
637 | 637 |
|
638 | 638 | async with async_session() as session: |
639 | 639 | async with session.begin(): |
640 | | - session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) |
| 640 | + session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])])) |
641 | 641 | item = await session.get(Item, 1) |
642 | 642 | assert item.embeddings[0].tolist() == [1, 2, 3] |
643 | 643 | assert item.embeddings[1].tolist() == [4, 5, 6] |
644 | 644 |
|
| 645 | + session.add(Item(id=2, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) |
| 646 | + item = await session.get(Item, 2) |
| 647 | + assert item.embeddings[0].tolist() == [1, 2, 3] |
| 648 | + assert item.embeddings[1].tolist() == [4, 5, 6] |
| 649 | + |
645 | 650 | await engine.dispose() |
0 commit comments