Skip to content

Commit aa7de6e

Browse files
committed
Vector: Fix type checking and compatibility with SQLAlchemy 1.x
1 parent 1f30b4b commit aa7de6e

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/sqlalchemy_cratedb/type/vector.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
- The type implementation might want to be accompanied by corresponding support
2626
for the `KNN_MATCH` function, similar to what the dialect already offers for
2727
fulltext search through its `Match` predicate.
28+
- After dropping support for SQLAlchemy 1.3, use
29+
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
2830
2931
## Origin
3032
This module is based on the corresponding pgvector implementation
@@ -49,7 +51,7 @@
4951
]
5052

5153

52-
def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
54+
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
5355
import numpy as np
5456

5557
# from `pgvector.utils`
@@ -82,8 +84,7 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
8284
return value
8385

8486

85-
class FloatVector(sa.TypeDecorator[t.Sequence[float]]):
86-
87+
class FloatVector(sa.TypeDecorator):
8788
"""
8889
SQLAlchemy `FloatVector` data type for CrateDB.
8990
"""
@@ -108,14 +109,14 @@ def as_generic(self, allow_nulltype=False):
108109
def python_type(self):
109110
return list
110111

111-
def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
112+
def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
112113
def process(value: t.Iterable) -> t.Optional[t.List]:
113114
return to_db(value, self.dimensions)
114115

115116
return process
116117

117-
def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
118-
def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
118+
def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
119+
def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
119120
return from_db(value)
120121

121122
return process

0 commit comments

Comments
 (0)