|
1 | 1 | import numpy as np |
2 | | -from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector |
| 2 | +from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum |
3 | 3 | import pytest |
4 | 4 | from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer |
5 | 5 | from sqlalchemy.exc import StatementError |
@@ -339,41 +339,39 @@ def test_select_orm(self): |
339 | 339 |
|
340 | 340 | def test_avg(self): |
341 | 341 | with Session(engine) as session: |
342 | | - avg = session.query(func.avg(Item.embedding)).first()[0] |
343 | | - assert avg is None |
| 342 | + res = session.query(avg(Item.embedding)).first()[0] |
| 343 | + assert res is None |
344 | 344 | session.add(Item(embedding=[1, 2, 3])) |
345 | 345 | session.add(Item(embedding=[4, 5, 6])) |
346 | | - avg = session.query(func.avg(Item.embedding)).first()[0] |
347 | | - # does not type cast |
348 | | - assert avg == '[2.5,3.5,4.5]' |
| 346 | + res = session.query(avg(Item.embedding)).first()[0] |
| 347 | + assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) |
349 | 348 |
|
350 | 349 | def test_avg_orm(self): |
351 | 350 | with Session(engine) as session: |
352 | | - avg = session.scalars(select(func.avg(Item.embedding))).first() |
353 | | - assert avg is None |
| 351 | + res = session.scalars(select(avg(Item.embedding))).first() |
| 352 | + assert res is None |
354 | 353 | session.add(Item(embedding=[1, 2, 3])) |
355 | 354 | session.add(Item(embedding=[4, 5, 6])) |
356 | | - avg = session.scalars(select(func.avg(Item.embedding))).first() |
357 | | - # does not type cast |
358 | | - assert avg == '[2.5,3.5,4.5]' |
| 355 | + res = session.scalars(select(avg(Item.embedding))).first() |
| 356 | + assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) |
359 | 357 |
|
360 | 358 | def test_sum(self): |
361 | 359 | with Session(engine) as session: |
362 | | - sum = session.query(func.sum(Item.embedding)).first()[0] |
363 | | - assert sum is None |
| 360 | + res = session.query(sum(Item.embedding)).first()[0] |
| 361 | + assert res is None |
364 | 362 | session.add(Item(embedding=[1, 2, 3])) |
365 | 363 | session.add(Item(embedding=[4, 5, 6])) |
366 | | - sum = session.query(func.sum(Item.embedding)).first()[0] |
367 | | - assert np.array_equal(sum, np.array([5, 7, 9])) |
| 364 | + res = session.query(sum(Item.embedding)).first()[0] |
| 365 | + assert np.array_equal(res, np.array([5, 7, 9])) |
368 | 366 |
|
369 | 367 | def test_sum_orm(self): |
370 | 368 | with Session(engine) as session: |
371 | | - sum = session.scalars(select(func.sum(Item.embedding))).first() |
372 | | - assert sum is None |
| 369 | + res = session.scalars(select(sum(Item.embedding))).first() |
| 370 | + assert res is None |
373 | 371 | session.add(Item(embedding=[1, 2, 3])) |
374 | 372 | session.add(Item(embedding=[4, 5, 6])) |
375 | | - sum = session.scalars(select(func.sum(Item.embedding))).first() |
376 | | - assert np.array_equal(sum, np.array([5, 7, 9])) |
| 373 | + res = session.scalars(select(sum(Item.embedding))).first() |
| 374 | + assert np.array_equal(res, np.array([5, 7, 9])) |
377 | 375 |
|
378 | 376 | def test_bad_dimensions(self): |
379 | 377 | item = Item(embedding=[1, 2]) |
|
0 commit comments