Skip to content

Commit 8a621a3

Browse files
committed
Changed default values of globally and arrays for register_type with Psycopg 2 [skip ci]
1 parent 37b148f commit 8a621a3

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

pgvector/psycopg2/register.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
from .vector import register_vector_info
66

77

8-
# TODO make globally False by default in 0.4.0
98
# note: register_adapter is always global
10-
# TODO make arrays True by defalt in 0.4.0
11-
def register_vector(conn_or_curs, globally=True, arrays=False):
9+
def register_vector(conn_or_curs, globally=False, arrays=True):
1210
conn = conn_or_curs if hasattr(conn_or_curs, 'cursor') else conn_or_curs.connection
1311
cur = conn.cursor(cursor_factory=cursor)
1412
scope = None if globally else conn_or_curs

tests/test_psycopg2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
cur.execute('DROP TABLE IF EXISTS psycopg2_items')
1313
cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), embeddings vector[], half_embeddings halfvec[], sparse_embeddings sparsevec[])')
1414

15-
register_vector(cur, globally=False, arrays=True)
15+
register_vector(cur)
1616

1717

1818
class TestPsycopg2:
@@ -87,13 +87,13 @@ def test_cursor_factory(self):
8787
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
8888
conn = psycopg2.connect(dbname='pgvector_python_test')
8989
cur = conn.cursor(cursor_factory=cursor_factory)
90-
register_vector(cur, globally=False)
90+
register_vector(cur)
9191
conn.close()
9292

9393
def test_cursor_factory_connection(self):
9494
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
9595
conn = psycopg2.connect(dbname='pgvector_python_test', cursor_factory=cursor_factory)
96-
register_vector(conn, globally=False)
96+
register_vector(conn)
9797
conn.close()
9898

9999
def test_pool(self):
@@ -102,7 +102,7 @@ def test_pool(self):
102102
conn = pool.getconn()
103103
try:
104104
# use globally=True for apps to ensure registered with all connections
105-
register_vector(conn, globally=False)
105+
register_vector(conn)
106106
finally:
107107
pool.putconn(conn)
108108

tests/test_sqlalchemy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@event.listens_for(psycopg2_type_engine, "connect")
2525
def psycopg2_connect(dbapi_connection, connection_record):
2626
from pgvector.psycopg2 import register_vector
27-
register_vector(dbapi_connection, globally=False, arrays=True)
27+
register_vector(dbapi_connection)
2828

2929

3030
pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test')

0 commit comments

Comments
 (0)