Skip to content

Commit da977d9

Browse files
committed
Vector: Add wrapper for HNSW matching function KNN_MATCH
1 parent d84248f commit da977d9

File tree

5 files changed

+69
-15
lines changed

5 files changed

+69
-15
lines changed

docs/working-with-types.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ Vector type
264264
CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
265265
allows storing dense vectors of float values of fixed length.
266266

267-
>>> from sqlalchemy_cratedb.type.vector import FloatVector
267+
>>> from sqlalchemy_cratedb import FloatVector, knn_match
268268

269269
>>> class SearchIndex(Base):
270270
... __tablename__ = 'search'
@@ -285,6 +285,14 @@ When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy arr
285285
>>> query.all()
286286
[('foo', array([42.42, 43.43, 44.44], dtype=float32))]
287287

288+
To run a similarity search, i.e. to match embeddings against each other, use the
289+
:ref:`crate-reference:scalar_knn_match` function like this.
290+
291+
>>> query = session.query(SearchIndex.name) \
292+
... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
293+
>>> query.all()
294+
[('foo',)]
295+
288296
.. hidden: Disconnect from database
289297
290298
>>> session.close()

src/sqlalchemy_cratedb/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .type.array import ObjectArray
2828
from .type.geo import Geopoint, Geoshape
2929
from .type.object import ObjectType
30-
from .type.vector import FloatVector
30+
from .type.vector import FloatVector, knn_match
3131

3232
if SA_VERSION < SA_1_4:
3333
import textwrap
@@ -58,4 +58,5 @@
5858
ObjectArray,
5959
ObjectType,
6060
match,
61+
knn_match,
6162
]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .array import ObjectArray
22
from .geo import Geopoint, Geoshape
33
from .object import ObjectType
4-
from .vector import FloatVector
4+
from .vector import FloatVector, knn_match

src/sqlalchemy_cratedb/type/vector.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@
2222
<=>: cosine_distance
2323
2424
## Backlog
25-
- The type implementation might want to be accompanied by corresponding support
26-
for the `KNN_MATCH` function, similar to what the dialect already offers for
27-
fulltext search through its `Match` predicate.
2825
- After dropping support for SQLAlchemy 1.3, use
2926
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
3027
@@ -42,10 +39,13 @@
4239
import numpy.typing as npt # pragma: no cover
4340

4441
import sqlalchemy as sa
42+
from sqlalchemy.sql.expression import ColumnElement, literal
43+
from sqlalchemy.ext.compiler import compiles
4544

4645

4746
__all__ = [
4847
"from_db",
48+
"knn_match",
4949
"to_db",
5050
"FloatVector",
5151
]
@@ -131,7 +131,7 @@ class KnnMatch(ColumnElement):
131131
inherit_cache = True
132132

133133
def __init__(self, column, term, k=None):
134-
super(KnnMatch, self).__init__()
134+
super().__init__()
135135
self.column = column
136136
self.term = term
137137
self.k = k
@@ -150,11 +150,10 @@ def knn_match(column, term, k):
150150
"""
151151
Generate a match predicate for vector search.
152152
153-
:param column: A reference to a column or an index, or a subcolumn, or a
154-
dictionary of subcolumns with boost values.
153+
:param column: A reference to a column or an index.
155154
156155
:param term: The term to match against. This is an array of floating point
157-
values, which is compared to other vectors using a HNSW index.
156+
values, which is compared to other vectors using an HNSW index search.
158157
159158
:param k: The `k` argument determines the number of nearest neighbours to
160159
search in the index.
@@ -165,9 +164,9 @@ def knn_match(column, term, k):
165164
@compiles(KnnMatch)
166165
def compile_knn_match(knn_match, compiler, **kwargs):
167166
"""
168-
Clause compiler for `knn_match`.
167+
Clause compiler for `KNN_MATCH`.
169168
"""
170-
return "knn_match(%s, %s, %s)" % (
169+
return "KNN_MATCH(%s, %s, %s)" % (
171170
knn_match.compile_column(compiler),
172171
knn_match.compile_term(compiler),
173172
knn_match.compile_k(compiler),

tests/vector_test.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,18 @@
2828

2929
import pytest
3030
import sqlalchemy as sa
31-
from sqlalchemy.orm import Session
31+
from sqlalchemy.orm import Session, sessionmaker
3232
from sqlalchemy.sql import select
3333

34-
from sqlalchemy_cratedb import SA_VERSION, SA_1_4
35-
from sqlalchemy_cratedb.type import FloatVector
34+
try:
35+
from sqlalchemy.orm import declarative_base
36+
except ImportError:
37+
from sqlalchemy.ext.declarative import declarative_base
3638

3739
from crate.client.cursor import Cursor
3840

41+
from sqlalchemy_cratedb import SA_VERSION, SA_1_4
42+
from sqlalchemy_cratedb import FloatVector, knn_match
3943
from sqlalchemy_cratedb.type.vector import from_db, to_db
4044

4145
fake_cursor = MagicMock(name="fake_cursor")
@@ -102,6 +106,14 @@ def test_sql_select(self):
102106
"SELECT testdrive.data FROM testdrive", select(self.table.c.data)
103107
)
104108

109+
def test_sql_match(self):
110+
query = self.session.query(self.table.c.name) \
111+
.filter(knn_match(self.table.c.data, [42.42, 43.43], 3))
112+
self.assertSQL(
113+
"SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)",
114+
query
115+
)
116+
105117

106118
def test_from_db_success():
107119
"""
@@ -201,3 +213,37 @@ def test_float_vector_as_generic():
201213
fv = FloatVector(3)
202214
assert isinstance(fv.as_generic(), sa.ARRAY)
203215
assert fv.python_type is list
216+
217+
218+
def test_float_vector_integration():
219+
"""
220+
An integration test for `FLOAT_VECTOR` and `KNN_MATCH`.
221+
"""
222+
np = pytest.importorskip("numpy")
223+
224+
engine = sa.create_engine(f"crate://")
225+
session = sessionmaker(bind=engine)()
226+
Base = declarative_base()
227+
228+
# Define DDL.
229+
class SearchIndex(Base):
230+
__tablename__ = 'search'
231+
name = sa.Column(sa.String, primary_key=True)
232+
embedding = sa.Column(FloatVector(3))
233+
234+
Base.metadata.drop_all(engine, checkfirst=True)
235+
Base.metadata.create_all(engine, checkfirst=True)
236+
237+
# Insert record.
238+
foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44])
239+
session.add(foo_item)
240+
session.commit()
241+
session.execute(sa.text("REFRESH TABLE search"))
242+
243+
# Query record.
244+
query = session.query(SearchIndex.embedding) \
245+
.filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
246+
result = query.first()
247+
248+
# Compare outcome.
249+
assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32))

0 commit comments

Comments
 (0)