|
1 | 1 | import numpy as np |
2 | | -from pgvector.psycopg2 import register_vector, SparseVector |
| 2 | +from pgvector.psycopg2 import register_vector, HalfVector, SparseVector |
3 | 3 | import psycopg2 |
4 | 4 | from psycopg2.extras import DictCursor, RealDictCursor, NamedTupleCursor |
5 | 5 |
|
|
9 | 9 | cur = conn.cursor() |
10 | 10 | cur.execute('CREATE EXTENSION IF NOT EXISTS vector') |
11 | 11 | cur.execute('DROP TABLE IF EXISTS psycopg2_items') |
12 | | -cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))') |
| 12 | +cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), embeddings vector[], half_embeddings halfvec[], sparse_embeddings sparsevec[])') |
13 | 13 |
|
14 | | -register_vector(cur, globally=False) |
| 14 | +register_vector(cur, globally=False, arrays=True) |
15 | 15 |
|
16 | 16 |
|
17 | 17 | class TestPsycopg2: |
@@ -55,6 +55,33 @@ def test_sparsevec(self): |
55 | 55 | assert res[0][0].to_list() == [1.5, 2, 3] |
56 | 56 | assert res[1][0] is None |
57 | 57 |
|
| 58 | + def test_vector_array(self): |
| 59 | + embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] |
| 60 | + cur.execute('INSERT INTO psycopg2_items (embeddings) VALUES (%s::vector[])', (embeddings,)) |
| 61 | + |
| 62 | + cur.execute('SELECT embeddings FROM psycopg2_items ORDER BY id') |
| 63 | + res = cur.fetchone() |
| 64 | + assert np.array_equal(res[0][0], embeddings[0]) |
| 65 | + assert np.array_equal(res[0][1], embeddings[1]) |
| 66 | + |
| 67 | + def test_halfvec_array(self): |
| 68 | + embeddings = [HalfVector([1.5, 2, 3]), HalfVector([4.5, 5, 6])] |
| 69 | + cur.execute('INSERT INTO psycopg2_items (half_embeddings) VALUES (%s::halfvec[])', (embeddings,)) |
| 70 | + |
| 71 | + cur.execute('SELECT half_embeddings FROM psycopg2_items ORDER BY id') |
| 72 | + res = cur.fetchone() |
| 73 | + assert res[0][0].to_list() == [1.5, 2, 3] |
| 74 | + assert res[0][1].to_list() == [4.5, 5, 6] |
| 75 | + |
| 76 | + def test_sparsevec_array(self): |
| 77 | + embeddings = [SparseVector([1.5, 2, 3]), SparseVector([4.5, 5, 6])] |
| 78 | + cur.execute('INSERT INTO psycopg2_items (sparse_embeddings) VALUES (%s::sparsevec[])', (embeddings,)) |
| 79 | + |
| 80 | + cur.execute('SELECT sparse_embeddings FROM psycopg2_items ORDER BY id') |
| 81 | + res = cur.fetchone() |
| 82 | + assert res[0][0].to_list() == [1.5, 2, 3] |
| 83 | + assert res[0][1].to_list() == [4.5, 5, 6] |
| 84 | + |
58 | 85 | def test_cursor_factory(self): |
59 | 86 | for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]: |
60 | 87 | conn = psycopg2.connect(dbname='pgvector_python_test') |
|
0 commit comments