Skip to content

Commit 70ae49a

Browse files
committed
Updated tests to use globally=False
1 parent b88ebed commit 70ae49a

File tree

4 files changed

+19
-11
lines changed

4 files changed

+19
-11
lines changed

tests/test_peewee.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,31 +169,35 @@ def test_vector_avg(self):
169169
Item.create(embedding=[1, 2, 3])
170170
Item.create(embedding=[4, 5, 6])
171171
avg = Item.select(fn.avg(Item.embedding)).scalar()
172-
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
172+
# does not type cast
173+
assert avg == '[2.5,3.5,4.5]'
173174

174175
def test_vector_sum(self):
175176
sum = Item.select(fn.sum(Item.embedding)).scalar()
176177
assert sum is None
177178
Item.create(embedding=[1, 2, 3])
178179
Item.create(embedding=[4, 5, 6])
179180
sum = Item.select(fn.sum(Item.embedding)).scalar()
180-
assert np.array_equal(sum, np.array([5, 7, 9]))
181+
# does not type cast
182+
assert sum == '[5,7,9]'
181183

182184
def test_halfvec_avg(self):
183185
avg = Item.select(fn.avg(Item.half_embedding)).scalar()
184186
assert avg is None
185187
Item.create(half_embedding=[1, 2, 3])
186188
Item.create(half_embedding=[4, 5, 6])
187189
avg = Item.select(fn.avg(Item.half_embedding)).scalar()
188-
assert avg.to_list() == [2.5, 3.5, 4.5]
190+
# does not type cast
191+
assert avg == '[2.5,3.5,4.5]'
189192

190193
def test_halfvec_sum(self):
191194
sum = Item.select(fn.sum(Item.half_embedding)).scalar()
192195
assert sum is None
193196
Item.create(half_embedding=[1, 2, 3])
194197
Item.create(half_embedding=[4, 5, 6])
195198
sum = Item.select(fn.sum(Item.half_embedding)).scalar()
196-
assert sum.to_list() == [5, 7, 9]
199+
# does not type cast
200+
assert sum == '[5,7,9]'
197201

198202
def test_get_or_create(self):
199203
Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})

tests/test_psycopg2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
cur.execute('DROP TABLE IF EXISTS psycopg2_items')
1212
cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))')
1313

14-
register_vector(cur)
14+
register_vector(cur, globally=False)
1515

1616

1717
class TestPsycopg2:
@@ -59,11 +59,11 @@ def test_cursor_factory(self):
5959
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
6060
conn = psycopg2.connect(dbname='pgvector_python_test')
6161
cur = conn.cursor(cursor_factory=cursor_factory)
62-
register_vector(cur)
62+
register_vector(cur, globally=False)
6363
conn.close()
6464

6565
def test_cursor_factory_connection(self):
6666
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
6767
conn = psycopg2.connect(dbname='pgvector_python_test', cursor_factory=cursor_factory)
68-
register_vector(conn)
68+
register_vector(conn, globally=False)
6969
conn.close()

tests/test_sqlalchemy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ def test_avg(self):
337337
session.add(Item(embedding=[1, 2, 3]))
338338
session.add(Item(embedding=[4, 5, 6]))
339339
avg = session.query(func.avg(Item.embedding)).first()[0]
340-
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
340+
# does not type cast
341+
assert avg == '[2.5,3.5,4.5]'
341342

342343
def test_avg_orm(self):
343344
with Session(engine) as session:
@@ -346,7 +347,8 @@ def test_avg_orm(self):
346347
session.add(Item(embedding=[1, 2, 3]))
347348
session.add(Item(embedding=[4, 5, 6]))
348349
avg = session.scalars(select(func.avg(Item.embedding))).first()
349-
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
350+
# does not type cast
351+
assert avg == '[2.5,3.5,4.5]'
350352

351353
def test_sum(self):
352354
with Session(engine) as session:

tests/test_sqlmodel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ def test_vector_avg(self):
203203
session.add(Item(embedding=[1, 2, 3]))
204204
session.add(Item(embedding=[4, 5, 6]))
205205
avg = session.exec(select(func.avg(Item.embedding))).first()
206-
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
206+
# does not type cast
207+
assert avg == '[2.5,3.5,4.5]'
207208

208209
def test_vector_sum(self):
209210
with Session(engine) as session:
@@ -221,7 +222,8 @@ def test_halfvec_avg(self):
221222
session.add(Item(half_embedding=[1, 2, 3]))
222223
session.add(Item(half_embedding=[4, 5, 6]))
223224
avg = session.exec(select(func.avg(Item.half_embedding))).first()
224-
assert avg.to_list() == [2.5, 3.5, 4.5]
225+
# does not type cast
226+
assert avg == '[2.5,3.5,4.5]'
225227

226228
def test_halfvec_sum(self):
227229
with Session(engine) as session:

0 commit comments

Comments
 (0)