Skip to content

Commit 82b7ab9

Browse files
committed
Added arrays option for Psycopg 2
1 parent 0852a1f commit 82b7ab9

File tree

6 files changed

+61
-14
lines changed

6 files changed

+61
-14
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.3.6 (unreleased)
2+
3+
- Added `arrays` option for Psycopg 2
4+
15
## 0.3.5 (2024-10-05)
26

37
- Added `avg` function with type casting to SQLAlchemy

pgvector/psycopg2/halfvec.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from psycopg2.extensions import adapt, new_type, register_adapter, register_type
1+
from psycopg2.extensions import adapt, new_array_type, new_type, register_adapter, register_type
22
from ..utils import HalfVector
33

44

@@ -14,7 +14,12 @@ def cast_halfvec(value, cur):
1414
return HalfVector._from_db(value)
1515

1616

17-
def register_halfvec_info(oid, scope):
17+
def register_halfvec_info(oid, array_oid, scope):
1818
halfvec = new_type((oid,), 'HALFVEC', cast_halfvec)
1919
register_type(halfvec, scope)
20+
21+
if array_oid is not None:
22+
halfvecarray = new_array_type((array_oid,), 'HALFVECARRAY', halfvec)
23+
register_type(halfvecarray, scope)
24+
2025
register_adapter(HalfVector, HalfvecAdapter)

pgvector/psycopg2/register.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,23 @@
77

88
# TODO make globally False by default in 0.4.0
99
# note: register_adapter is always global
10-
def register_vector(conn_or_curs=None, globally=True):
10+
# TODO make arrays True by defalt in 0.4.0
11+
def register_vector(conn_or_curs=None, globally=True, arrays=False):
1112
conn = conn_or_curs if hasattr(conn_or_curs, 'cursor') else conn_or_curs.connection
1213
cur = conn.cursor(cursor_factory=cursor)
1314
scope = None if globally else conn_or_curs
1415

1516
# use to_regtype to get first matching type in search path
16-
cur.execute("SELECT typname, oid FROM pg_type WHERE oid IN (to_regtype('vector'), to_regtype('halfvec'), to_regtype('sparsevec'))")
17+
cur.execute("SELECT typname, oid FROM pg_type WHERE oid IN (to_regtype('vector'), to_regtype('_vector'), to_regtype('halfvec'), to_regtype('_halfvec'), to_regtype('sparsevec'), to_regtype('_sparsevec'))")
1718
type_info = dict(cur.fetchall())
1819

1920
if 'vector' not in type_info:
2021
raise psycopg2.ProgrammingError('vector type not found in the database')
2122

22-
register_vector_info(type_info['vector'], scope)
23+
register_vector_info(type_info['vector'], type_info['_vector'] if arrays else None, scope)
2324

2425
if 'halfvec' in type_info:
25-
register_halfvec_info(type_info['halfvec'], scope)
26+
register_halfvec_info(type_info['halfvec'], type_info['_halfvec'] if arrays else None, scope)
2627

2728
if 'sparsevec' in type_info:
28-
register_sparsevec_info(type_info['sparsevec'], scope)
29+
register_sparsevec_info(type_info['sparsevec'], type_info['_sparsevec'] if arrays else None, scope)

pgvector/psycopg2/sparsevec.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from psycopg2.extensions import adapt, new_type, register_adapter, register_type
1+
from psycopg2.extensions import adapt, new_array_type, new_type, register_adapter, register_type
22
from ..utils import SparseVector
33

44

@@ -14,7 +14,12 @@ def cast_sparsevec(value, cur):
1414
return SparseVector._from_db(value)
1515

1616

17-
def register_sparsevec_info(oid, scope):
17+
def register_sparsevec_info(oid, array_oid, scope):
1818
sparsevec = new_type((oid,), 'SPARSEVEC', cast_sparsevec)
1919
register_type(sparsevec, scope)
20+
21+
if array_oid is not None:
22+
sparsevecarray = new_array_type((array_oid,), 'SPARSEVECARRAY', sparsevec)
23+
register_type(sparsevecarray, scope)
24+
2025
register_adapter(SparseVector, SparsevecAdapter)

pgvector/psycopg2/vector.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from psycopg2.extensions import adapt, new_type, register_adapter, register_type
2+
from psycopg2.extensions import adapt, new_array_type, new_type, register_adapter, register_type
33
from ..utils import Vector
44

55

@@ -15,7 +15,12 @@ def cast_vector(value, cur):
1515
return Vector._from_db(value)
1616

1717

18-
def register_vector_info(oid, scope):
18+
def register_vector_info(oid, array_oid, scope):
1919
vector = new_type((oid,), 'VECTOR', cast_vector)
2020
register_type(vector, scope)
21+
22+
if array_oid is not None:
23+
vectorarray = new_array_type((array_oid,), 'VECTORARRAY', vector)
24+
register_type(vectorarray, scope)
25+
2126
register_adapter(np.ndarray, VectorAdapter)

tests/test_psycopg2.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector.psycopg2 import register_vector, SparseVector
2+
from pgvector.psycopg2 import register_vector, HalfVector, SparseVector
33
import psycopg2
44
from psycopg2.extras import DictCursor, RealDictCursor, NamedTupleCursor
55

@@ -9,9 +9,9 @@
99
cur = conn.cursor()
1010
cur.execute('CREATE EXTENSION IF NOT EXISTS vector')
1111
cur.execute('DROP TABLE IF EXISTS psycopg2_items')
12-
cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))')
12+
cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), embeddings vector[], half_embeddings halfvec[], sparse_embeddings sparsevec[])')
1313

14-
register_vector(cur, globally=False)
14+
register_vector(cur, globally=False, arrays=True)
1515

1616

1717
class TestPsycopg2:
@@ -55,6 +55,33 @@ def test_sparsevec(self):
5555
assert res[0][0].to_list() == [1.5, 2, 3]
5656
assert res[1][0] is None
5757

58+
def test_vector_array(self):
59+
embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])]
60+
cur.execute('INSERT INTO psycopg2_items (embeddings) VALUES (%s::vector[])', (embeddings,))
61+
62+
cur.execute('SELECT embeddings FROM psycopg2_items ORDER BY id')
63+
res = cur.fetchone()
64+
assert np.array_equal(res[0][0], embeddings[0])
65+
assert np.array_equal(res[0][1], embeddings[1])
66+
67+
def test_halfvec_array(self):
68+
embeddings = [HalfVector([1.5, 2, 3]), HalfVector([4.5, 5, 6])]
69+
cur.execute('INSERT INTO psycopg2_items (half_embeddings) VALUES (%s::halfvec[])', (embeddings,))
70+
71+
cur.execute('SELECT half_embeddings FROM psycopg2_items ORDER BY id')
72+
res = cur.fetchone()
73+
assert res[0][0].to_list() == [1.5, 2, 3]
74+
assert res[0][1].to_list() == [4.5, 5, 6]
75+
76+
def test_sparsevec_array(self):
77+
embeddings = [SparseVector([1.5, 2, 3]), SparseVector([4.5, 5, 6])]
78+
cur.execute('INSERT INTO psycopg2_items (sparse_embeddings) VALUES (%s::sparsevec[])', (embeddings,))
79+
80+
cur.execute('SELECT sparse_embeddings FROM psycopg2_items ORDER BY id')
81+
res = cur.fetchone()
82+
assert res[0][0].to_list() == [1.5, 2, 3]
83+
assert res[0][1].to_list() == [4.5, 5, 6]
84+
5885
def test_cursor_factory(self):
5986
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
6087
conn = psycopg2.connect(dbname='pgvector_python_test')

0 commit comments

Comments
 (0)