Skip to content

Commit 174ec18

Browse files
committed
Vector: Add support for CrateDB's FLOAT_VECTOR data type: FloatVector
https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
1 parent 6b64619 commit 174ec18

File tree

9 files changed

+202
-4
lines changed

9 files changed

+202
-4
lines changed

CHANGES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33

44
## Unreleased
5+
- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying
6+
[KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions,
7+
you can use it like `FloatVector(dimensions=1536)`.
8+
9+
[FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
10+
[KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
511

612
## 2024/06/11 0.36.1
713
- Dependencies: Use `crate==1.0.0dev0`

docs/data-types.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ CrateDB SQLAlchemy
4545
`integer`__ `Integer`__
4646
`long`__ `NUMERIC`__
4747
`float`__ `Float`__
48+
`float_vector`__ ``FloatVector``
4849
`double`__ `DECIMAL`__
4950
`timestamp`__ `TIMESTAMP`__
5051
`string`__ `String`__
@@ -68,6 +69,7 @@ __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#n
6869
__ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.NUMERIC
6970
__ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data
7071
__ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.Float
72+
__ https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
7173
__ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data
7274
__ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.DECIMAL
7375
__ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#dates-and-times

docs/index.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ Install package from PyPI.
5050
pip install sqlalchemy-cratedb
5151
5252
The CrateDB dialect for `SQLAlchemy`_ offers convenient ORM access and supports
53-
CrateDB's ``OBJECT``, ``ARRAY``, and geospatial data types using `GeoJSON`_,
54-
supporting different kinds of `GeoJSON geometry objects`_.
53+
CrateDB's container data types ``OBJECT`` and ``ARRAY``, its vector data type
54+
``FLOAT_VECTOR``, and geospatial data types using `GeoJSON`_, supporting different
55+
kinds of `GeoJSON geometry objects`_.
5556

5657
.. toctree::
5758
:maxdepth: 2

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ dependencies = [
9292
"verlib2==0.2",
9393
]
9494
[project.optional-dependencies]
95+
all = [
96+
"sqlalchemy-cratedb[vector]",
97+
]
9598
develop = [
9699
"black<25",
97100
"mypy<1.11",
@@ -116,6 +119,9 @@ test = [
116119
"pytest-cov<6",
117120
"pytest-mock<4",
118121
]
122+
vector = [
123+
"numpy",
124+
]
119125
[project.urls]
120126
changelog = "https://github.com/crate-workbench/sqlalchemy-cratedb/blob/main/CHANGES.md"
121127
documentation = "https://github.com/crate-workbench/sqlalchemy-cratedb"

src/sqlalchemy_cratedb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .type.array import ObjectArray
2828
from .type.geo import Geopoint, Geoshape
2929
from .type.object import ObjectType
30+
from .type.vector import FloatVector
3031

3132
if SA_VERSION < SA_1_4:
3233
import textwrap
@@ -51,6 +52,7 @@
5152

5253
__all__ = [
5354
dialect,
55+
FloatVector,
5456
Geopoint,
5557
Geoshape,
5658
ObjectArray,

src/sqlalchemy_cratedb/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@ def visit_ARRAY(self, type_, **kw):
238238
def visit_OBJECT(self, type_, **kw):
239239
return "OBJECT"
240240

241+
def visit_FLOAT_VECTOR(self, type_, **kw):
    """
    Render the DDL type expression for a `FloatVector` column.

    :param type_: The type object being compiled; its `dimensions`
        attribute provides the vector size.
    :param kw: Additional compiler keyword arguments (unused).
    :return: A string of the form `FLOAT_VECTOR(<dimensions>)`.
    :raises ValueError: When `type_.dimensions` is `None`.
    """
    size = type_.dimensions
    if size is not None:
        return "FLOAT_VECTOR({})".format(size)
    # Without a dimension size, no valid type expression can be rendered.
    raise ValueError("FloatVector must be initialized with dimension size")
246+
241247

242248
class CrateCompiler(compiler.SQLCompiler):
243249

src/sqlalchemy_cratedb/dialect.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from crate.client.exceptions import TimezoneUnawareException
3535
from .sa_version import SA_VERSION, SA_1_4, SA_2_0
36-
from .type import ObjectArray, ObjectType
36+
from .type import FloatVector, ObjectArray, ObjectType
3737

3838
TYPES_MAP = {
3939
"boolean": sqltypes.Boolean,
@@ -51,7 +51,8 @@
5151
"float": sqltypes.Float,
5252
"real": sqltypes.Float,
5353
"string": sqltypes.String,
54-
"text": sqltypes.String
54+
"text": sqltypes.String,
55+
"float_vector": FloatVector,
5556
}
5657
try:
5758
# SQLAlchemy >= 1.1
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .array import ObjectArray
22
from .geo import Geopoint, Geoshape
33
from .object import ObjectType
4+
from .vector import FloatVector

src/sqlalchemy_cratedb/type/vector.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""
2+
## About
3+
SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.
4+
5+
## References
6+
- https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
7+
- https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
8+
9+
## Details
10+
The implementation is based on SQLAlchemy's `TypeDecorator`, and also
11+
offers compiler support.
12+
13+
## Notes
14+
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
15+
-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55
16+
17+
On the other hand, pgvector uses a comparator to apply different similarity
18+
functions as operators, see `pgvector.sqlalchemy.Vector.comparator_factory`.
19+
20+
<->: l2/euclidean_distance
21+
<#>: max_inner_product
22+
<=>: cosine_distance
23+
24+
## Backlog
25+
- The type implementation might want to be accompanied by corresponding support
26+
for the `KNN_MATCH` function, similar to what the dialect already offers for
27+
fulltext search through its `Match` predicate.
28+
29+
## Origin
30+
This module is based on the corresponding pgvector implementation
31+
by Andrew Kane. Thank you.
32+
33+
The MIT License (MIT)
34+
Copyright (c) 2021-2023 Andrew Kane
35+
https://github.com/pgvector/pgvector-python
36+
"""
37+
import typing as t

if t.TYPE_CHECKING:
    import numpy.typing as npt # pragma: no cover

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnElement, literal
43+
44+
45+
# Public API of this module. Includes the `KNN_MATCH` support defined
# further below, which was previously missing from the export list.
__all__ = [
    "from_db",
    "knn_match",
    "to_db",
    "FloatVector",
    "KnnMatch",
]
50+
51+
52+
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
    """
    Convert a vector value received from the database into a NumPy array.

    :param value: The vector as handed over by the driver, or `None`.
    :return: `None` when `value` is `None`; `value` itself when it already
        is a `numpy.ndarray`; otherwise a new `float32` array built from it.
    """
    # NOTE: The return annotation is quoted on purpose. `npt` is only imported
    # under `t.TYPE_CHECKING`, so an unquoted reference would raise `NameError`
    # as soon as this module is imported.

    # Imported lazily, so NumPy is only needed when vectors are actually used.
    import numpy as np

    # from `pgvector.utils`
    # could be ndarray if already cast by lower-level driver
    if value is None or isinstance(value, np.ndarray):
        return value

    return np.array(value, dtype=np.float32)
61+
62+
63+
def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
    """
    Convert a vector value into the representation sent to the database.

    :param value: `None`, a one-dimensional numeric `numpy.ndarray`, or any
        other sized value, which is passed through as-is.
    :param dim: Expected number of elements; `None` disables the length check.
    :return: `None` for `None` input, a plain list for `ndarray` input,
        otherwise the unchanged `value`.
    :raises ValueError: When an `ndarray` is not one-dimensional or not of an
        integer/floating dtype, or when the length does not equal `dim`.
    """
    import numpy as np

    # from `pgvector.utils`
    if value is None:
        return None

    if isinstance(value, np.ndarray):
        if value.ndim != 1:
            raise ValueError("expected ndim to be 1")

        dtype = value.dtype
        is_numeric = np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating)
        if not is_numeric:
            raise ValueError("dtype must be numeric")

        value = value.tolist()

    # Without an expected size, there is nothing left to validate.
    if dim is None:
        return value

    if len(value) != dim:
        raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))

    return value
83+
84+
85+
class FloatVector(sa.TypeDecorator):
    """
    SQLAlchemy `FloatVector` data type for CrateDB.

    The type decorates `sa.ARRAY` with `sa.FLOAT` items. Values are converted
    with `to_db` before being sent to the database, and with `from_db` when
    read from result rows.

    NOTE: `sa.TypeDecorator` is deliberately not subscripted, and `Dialect`
    annotations are quoted, because both `sa.TypeDecorator[...]` and
    `sa.Dialect` are only available on SQLAlchemy 2.x.
    """

    # Mark the type as not safe for SQLAlchemy's compiled-statement cache.
    cache_ok = False

    # Dispatches type compilation to `visit_FLOAT_VECTOR`.
    __visit_name__ = "FLOAT_VECTOR"

    # presumably mirrors the corresponding flags of `sa.ARRAY` — TODO confirm
    _is_array = True

    zero_indexes = False

    impl = sa.ARRAY

    def __init__(self, dimensions: t.Optional[int] = None):
        """
        :param dimensions: Size of the vector. It is forwarded to the
            underlying `sa.ARRAY` implementation type.
        """
        super().__init__(sa.FLOAT, dimensions=dimensions)

    def as_generic(self, allow_nulltype=False):
        """Return a generic `sa.ARRAY` of `sa.FLOAT`; `allow_nulltype` is ignored."""
        return sa.ARRAY(item_type=sa.FLOAT)

    @property
    def python_type(self):
        """The Python type reported for values of this column type."""
        return list

    def bind_processor(self, dialect: "sa.engine.Dialect") -> t.Callable:
        """Return a callable converting bound values using `to_db`."""

        def process(value: t.Iterable) -> t.Optional[t.List]:
            # Also validates the vector length against `self.dimensions`.
            return to_db(value, self.dimensions)

        return process

    def result_processor(self, dialect: "sa.engine.Dialect", coltype: t.Any) -> t.Callable:
        """Return a callable converting result values using `from_db`."""

        # The annotation is quoted, because `npt` is only imported under
        # `t.TYPE_CHECKING` and would otherwise raise `NameError` here.
        def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
            return from_db(value)

        return process
122+
123+
124+
class KnnMatch(ColumnElement):
    """
    Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function.

    https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match

    Instances hold the three call arguments, and offer one helper per
    argument to render it with a given statement compiler.
    """

    inherit_cache = True

    def __init__(self, column, term, k=None):
        """
        :param column: The column expression to search.
        :param term: The search vector.
        :param k: The number of nearest neighbours to search for.
        """
        super().__init__()
        self.column, self.term, self.k = column, term, k

    def compile_column(self, compiler):
        """Render the column expression with `compiler`."""
        return compiler.process(self.column)

    def compile_term(self, compiler):
        """Render the search term with `compiler`, wrapped by `literal()`."""
        term_literal = literal(self.term)
        return compiler.process(term_literal)

    def compile_k(self, compiler):
        """Render the `k` argument with `compiler`, wrapped by `literal()`."""
        k_literal = literal(self.k)
        return compiler.process(k_literal)
146+
147+
148+
def knn_match(column, term, k):
    """
    Generate a match predicate for vector search.

    :param column: A reference to the column to search. Presumably this needs
                   to be a `FLOAT_VECTOR` column — verify against the CrateDB
                   documentation for `KNN_MATCH`.

    :param term: The term to match against. This is an array of floating point
                 values, which is compared to other vectors using a HNSW index.

    :param k: The `k` argument determines the number of nearest neighbours to
              search in the index.

    :return: A `KnnMatch` expression element wrapping the three arguments.
    """
    predicate = KnnMatch(column, term, k)
    return predicate
162+
163+
164+
@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
    """
    Clause compiler for `knn_match`.

    :param knn_match: The `KnnMatch` element to render.
    :param compiler: The statement compiler used to render the arguments.
    :param kwargs: Additional compiler keyword arguments (unused).
    :return: The SQL fragment `knn_match(<column>, <term>, <k>)`.
    """
    # Render the three arguments in call order: column, term, k.
    rendered_args = (
        knn_match.compile_column(compiler),
        knn_match.compile_term(compiler),
        knn_match.compile_k(compiler),
    )
    return "knn_match(%s, %s, %s)" % rendered_args

0 commit comments

Comments
 (0)