Skip to content

Commit d84248f

Browse files
committed
Vector: Add software tests
1 parent aa7de6e commit d84248f

File tree

5 files changed

+244
-5
lines changed

5 files changed

+244
-5
lines changed

DEVELOP.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ further commands.
1616

1717
Verify code by running all linters and software tests:
1818

19+
export CRATEDB_VERSION=latest
1920
docker compose -f tests/docker-compose.yml up
2021
poe check
2122

docs/working-with-types.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are:
99

1010
- Container types ``ObjectType`` and ``ObjectArray``.
1111
- Geospatial types ``Geopoint`` and ``Geoshape``.
12+
- Vector data type ``FloatVector``.
1213

1314

1415
.. rubric:: Table of Contents
@@ -257,6 +258,33 @@ objects:
257258
[('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})]
258259

259260

261+
Vector type
262+
===========
263+
264+
CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
265+
allows storing dense vectors of float values of fixed length.
266+
267+
>>> from sqlalchemy_cratedb.type.vector import FloatVector
268+
269+
>>> class SearchIndex(Base):
270+
... __tablename__ = 'search'
271+
... name = sa.Column(sa.String, primary_key=True)
272+
... embedding = sa.Column(FloatVector(3))
273+
274+
Create an entity and store it into the database. ``float_vector`` values
275+
can be defined by using arrays of floating point numbers.
276+
277+
>>> foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44])
278+
>>> session.add(foo_item)
279+
>>> session.commit()
280+
>>> _ = connection.execute(sa.text("REFRESH TABLE search"))
281+
282+
When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array.
283+
284+
>>> query = session.query(SearchIndex.name, SearchIndex.embedding)
285+
>>> query.all()
286+
[('foo', array([42.42, 43.43, 44.44], dtype=float32))]
287+
260288
.. hidden: Disconnect from database
261289
262290
>>> session.close()

src/sqlalchemy_cratedb/type/vector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.
44
55
## References
6-
- https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
7-
- https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
6+
- https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
7+
- https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
88
99
## Details
1010
The implementation is based on SQLAlchemy's `TypeDecorator`, and also
@@ -14,8 +14,8 @@
1414
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
1515
-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55
1616
17-
On the other hand, pgvector use a comparator to apply different similarity
18-
functions as operators, see `pgvector.sqlalchemy.Vector.comparator_factory`.
17+
pgvector uses a comparator to apply different similarity functions as operators,
18+
see `pgvector.sqlalchemy.Vector.comparator_factory`.
1919
2020
<->: l2/euclidean_distance
2121
<#>: max_inner_product

tests/integration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,13 @@ def provision_database():
105105
name STRING PRIMARY KEY,
106106
coordinate GEO_POINT,
107107
area GEO_SHAPE
108-
)"""
108+
)""",
109+
"""
110+
CREATE TABLE search (
111+
name STRING PRIMARY KEY,
112+
text STRING,
113+
embedding FLOAT_VECTOR(3)
114+
)""",
109115
]
110116
_execute_statements(ddl_statements)
111117

@@ -120,6 +126,7 @@ def drop_tables():
120126
"DROP TABLE IF EXISTS cities",
121127
"DROP TABLE IF EXISTS locations",
122128
"DROP BLOB TABLE IF EXISTS myfiles",
129+
"DROP TABLE IF EXISTS search",
123130
'DROP TABLE IF EXISTS "test-testdrive"',
124131
"DROP TABLE IF EXISTS todos",
125132
'DROP TABLE IF EXISTS "user"',

tests/vector_test.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# -*- coding: utf-8; -*-
2+
#
3+
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
4+
# license agreements. See the NOTICE file distributed with this work for
5+
# additional information regarding copyright ownership. Crate licenses
6+
# this file to you under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License. You may
8+
# obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# However, if you have executed another commercial license agreement
19+
# with Crate these terms will supersede the license and you may use the
20+
# software solely pursuant to the terms of the relevant commercial agreement.
21+
22+
from __future__ import absolute_import
23+
24+
import re
25+
import sys
26+
from unittest import TestCase
27+
from unittest.mock import MagicMock, patch
28+
29+
import pytest
30+
import sqlalchemy as sa
31+
from sqlalchemy.orm import Session
32+
from sqlalchemy.sql import select
33+
34+
from sqlalchemy_cratedb import SA_VERSION, SA_1_4
35+
from sqlalchemy_cratedb.type import FloatVector
36+
37+
from crate.client.cursor import Cursor
38+
39+
from sqlalchemy_cratedb.type.vector import from_db, to_db
40+
41+
# Shared cursor double; the tests below inspect the calls recorded on its
# `execute` attribute.
fake_cursor = MagicMock(name="fake_cursor")

# Cursor factory double, constrained to the interface of the real `Cursor`
# class. Calling it always hands back the shared `fake_cursor`.
FakeCursor = MagicMock(name="FakeCursor", spec=Cursor, return_value=fake_cursor)


# Skip the whole module on SQLAlchemy versions older than 1.4.
if SA_VERSION < SA_1_4:
    pytest.skip(
        reason="The FloatVector type is not supported on SQLAlchemy 1.3 and earlier",
        allow_module_level=True,
    )
48+
49+
50+
@patch("crate.client.connection.Cursor", FakeCursor)
class SqlAlchemyVectorTypeTest(TestCase):
    """
    Check the SQL emitted for a table whose schema uses the `FloatVector` type.

    `crate.client.connection.Cursor` is patched with `FakeCursor`, so the
    assertions run against the `execute` calls recorded on `fake_cursor`.
    """

    def setUp(self):
        """Create the engine, the `testdrive` table definition, and a session."""
        self.engine = sa.create_engine("crate://")
        columns = (
            sa.Column("name", sa.String),
            sa.Column("data", FloatVector(3)),
        )
        self.table = sa.Table("testdrive", sa.MetaData(), *columns)
        self.session = Session(bind=self.engine)

    def assertSQL(self, expected_str, actual_expr):
        """Compare `expected_str` with the rendered expression, newlines removed."""
        rendered = str(actual_expr).replace('\n', '')
        self.assertEqual(expected_str, rendered)

    def test_create_invoke(self):
        """Creating the table issues DDL declaring `data` as `FLOAT_VECTOR(3)`."""
        self.table.create(self.engine)
        expected_ddl = (
            "\nCREATE TABLE testdrive (\n\t"
            "name STRING, \n\t"
            "data FLOAT_VECTOR(3)\n)\n\n"
        )
        fake_cursor.execute.assert_called_with(expected_ddl, ())

    def test_insert_invoke(self):
        """Inserting a row passes the vector to the cursor as a plain list."""
        vector = [42.42, 43.43, 44.44]
        insert_stmt = self.table.insert().values(name="foo", data=vector)
        with self.engine.connect() as connection:
            connection.execute(insert_stmt)
        fake_cursor.execute.assert_called_with(
            "INSERT INTO testdrive (name, data) VALUES (?, ?)",
            ("foo", [42.42, 43.43, 44.44]),
        )

    def test_select_invoke(self):
        """Selecting the vector column issues a plain `SELECT` without parameters."""
        select_stmt = select(self.table.c.data)
        with self.engine.connect() as connection:
            connection.execute(select_stmt)
        fake_cursor.execute.assert_called_with(
            "SELECT testdrive.data \nFROM testdrive",
            (),
        )

    def test_sql_select(self):
        """The `SELECT` statement renders as expected without executing it."""
        select_stmt = select(self.table.c.data)
        self.assertSQL("SELECT testdrive.data FROM testdrive", select_stmt)
104+
105+
106+
def test_from_db_success():
    """
    Verify the inputs `sqlalchemy_cratedb.type.vector.from_db` accepts.
    """
    np = pytest.importorskip("numpy")

    # `None` comes back as `None` rather than as an array.
    assert from_db(None) is None

    # Each accepted input, paired with the value the expected float32 array
    # is built from.
    cases = [
        (False, 0.),
        (True, 1.),
        (42, 42),
        (42.42, 42.42),
        ([42.42, 43.43], [42.42, 43.43]),
        ("42.42", 42.42),
        (["42.42", "43.43"], [42.42, 43.43]),
    ]
    for value, reference in cases:
        expected = np.array(reference, dtype=np.float32)
        assert np.array_equal(from_db(value), expected)
119+
120+
121+
def test_from_db_failure():
    """
    Verify the inputs `sqlalchemy_cratedb.type.vector.from_db` rejects.
    """
    pytest.importorskip("numpy")

    # A non-numeric string is rejected, both bare and wrapped in a list.
    for bad_value in ("foo", ["foo"]):
        with pytest.raises(ValueError) as ex:
            from_db(bad_value)
        assert ex.match("could not convert string to float: 'foo'")

    # The expected `float()` error wording differs from Python 3.10 onwards.
    if sys.version_info < (3, 10):
        dict_message = "float() argument must be a string or a number, not 'dict'"
    else:
        dict_message = "float() argument must be a string or a real number, not 'dict'"

    with pytest.raises(TypeError) as ex:
        from_db({"foo": "bar"})
    # The message contains regex metacharacters, so escape it before matching.
    assert ex.match(re.escape(dict_message))
141+
142+
143+
def test_to_db_success():
    """
    Verify the inputs `sqlalchemy_cratedb.type.vector.to_db` accepts.
    """
    np = pytest.importorskip("numpy")

    # Singletons come back as the very same object.
    for singleton in (None, False, True):
        assert to_db(singleton) is singleton

    # Plain Python values compare equal to what was passed in.
    plain_values = (
        42,
        42.42,
        [42.42, 43.43],
        "42.42",
        "foo",
        ["foo"],
        {"foo": "bar"},
    )
    for plain in plain_values:
        assert to_db(plain) == plain

    # A one-dimensional NumPy array compares equal to the corresponding list.
    assert to_db(np.array([42.42, 43.43])) == [42.42, 43.43]

    # NOTE(review): every Python value is an instance of `object`, so this only
    # proves that the call does not raise -- confirm whether a stricter check
    # was intended.
    assert isinstance(to_db(object()), object)
160+
161+
162+
def test_to_db_failure():
    """
    Verify the inputs `sqlalchemy_cratedb.type.vector.to_db` rejects.
    """
    np = pytest.importorskip("numpy")

    # (array argument, extra keyword arguments, expected error message)
    failing_calls = [
        (np.array(["42.42", "43.43"]), {}, "dtype must be numeric"),
        (np.array([42.42, 43.43]), {"dim": 33}, "expected 33 dimensions, not 2"),
        (np.array([[42.42, 43.43]]), {}, "expected ndim to be 1"),
    ]
    for array, extra_kwargs, message in failing_calls:
        with pytest.raises(ValueError) as ex:
            to_db(array, **extra_kwargs)
        assert ex.match(message)
179+
180+
181+
def test_float_vector_no_dimension_size():
    """
    Verify a table using `FloatVector` without a dimension size cannot be created.

    The column is declared with the bare `FloatVector` class instead of an
    instance such as `FloatVector(3)`; creating the table is expected to raise
    a `ValueError` naming the missing dimension size.
    """
    engine = sa.create_engine("crate://")
    metadata = sa.MetaData()
    table = sa.Table(
        "foo",
        metadata,
        sa.Column("data", FloatVector),
    )
    with pytest.raises(ValueError) as ex:
        table.create(engine)
    # Assert on the match explicitly, like the sibling tests in this module do,
    # so the check does not read like a discarded expression.
    assert ex.match("FloatVector must be initialized with dimension size")
195+
196+
197+
def test_float_vector_as_generic():
    """
    Verify what `as_generic()` and `python_type` report for a `FloatVector`.
    """
    vector_type = FloatVector(3)

    # The generic counterpart is expected to be SQLAlchemy's `ARRAY` type.
    generic_type = vector_type.as_generic()
    assert isinstance(generic_type, sa.ARRAY)

    # The Python-side representation is expected to be a plain `list`.
    assert vector_type.python_type is list

0 commit comments

Comments
 (0)