|
1 | 1 | import numpy as np |
2 | 2 | from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum |
3 | 3 | import pytest |
4 | | -from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer |
| 4 | +from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY |
5 | 5 | from sqlalchemy.exc import StatementError |
6 | 6 | from sqlalchemy.ext.automap import automap_base |
7 | 7 | from sqlalchemy.orm import declarative_base, Session |
@@ -31,6 +31,7 @@ class Item(Base): |
31 | 31 | half_embedding = mapped_column(HALFVEC(3)) |
32 | 32 | binary_embedding = mapped_column(BIT(3)) |
33 | 33 | sparse_embedding = mapped_column(SPARSEVEC(3)) |
| 34 | + embeddings = mapped_column(ARRAY(VECTOR(3))) |
34 | 35 |
|
35 | 36 |
|
36 | 37 | Base.metadata.drop_all(engine) |
@@ -70,7 +71,8 @@ def test_core(self): |
70 | 71 | Column('embedding', VECTOR(3)), |
71 | 72 | Column('half_embedding', HALFVEC(3)), |
72 | 73 | Column('binary_embedding', BIT(3)), |
73 | | - Column('sparse_embedding', SPARSEVEC(3)) |
| 74 | + Column('sparse_embedding', SPARSEVEC(3)), |
| 75 | + Column('embeddings', ARRAY(VECTOR(3))) |
74 | 76 | ) |
75 | 77 |
|
76 | 78 | metadata.drop_all(engine) |
@@ -422,6 +424,14 @@ def test_automap(self): |
422 | 424 | item = session.query(AutoItem).first() |
423 | 425 | assert item.embedding.tolist() == [1, 2, 3] |
424 | 426 |
|
| 427 | + def test_vector_array(self): |
| 428 | + session = Session(engine) |
| 429 | + session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) |
| 430 | + session.commit() |
| 431 | + |
| 432 | + # this fails if the driver does not cast arrays |
| 433 | + # item = session.get(Item, 1) |
| 434 | + |
425 | 435 | @pytest.mark.asyncio |
426 | 436 | @pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+') |
427 | 437 | async def test_async(self): |
|
0 commit comments