|
1 | 1 | import datetime |
2 | 2 | from decimal import Decimal |
3 | | -from sqlalchemy import func |
| 3 | +from sqlalchemy import func, text |
4 | 4 | from sqlalchemy.sql import sqltypes |
5 | | -from sqlalchemy.types import UserDefinedType |
| 5 | +from sqlalchemy.types import UserDefinedType, Float |
6 | 6 | from uuid import UUID as _python_UUID |
7 | 7 | from intersystems_iris import IRISList |
8 | 8 |
|
@@ -247,6 +247,72 @@ def func(self, funcname: str, other): |
247 | 247 | return getattr(func, funcname)(self, irislist.getBuffer()) |
248 | 248 |
|
249 | 249 |
|
class IRISVector(UserDefinedType):
    """SQLAlchemy column type for an InterSystems IRIS ``VECTOR`` column.

    Python-side values are lists/tuples of numbers.  They are sent to the
    server as a bracketed, comma-separated string (``"[1,2,3]"``) and read
    back by splitting the returned string on commas and converting each item
    with ``item_type``.
    """

    cache_ok = True

    def __init__(self, max_items: int = None, item_type: type = float):
        """Create the type.

        Args:
            max_items: Declared vector length used in the column DDL, or
                ``None`` to leave it unspecified.
            item_type: Python type of the elements; one of ``int``,
                ``float`` or ``Decimal``.

        Raises:
            TypeError: If ``item_type`` is not one of the supported types.
        """
        super().__init__()
        if item_type not in (float, int, Decimal):
            # Report the rejected argument itself (the original interpolated
            # the builtin ``type``, which always printed "type").  Fall back
            # to repr() for non-class arguments that have no ``__name__``.
            got = getattr(item_type, "__name__", repr(item_type))
            raise TypeError(
                f"IRISVector expected int, float or Decimal; got {got}; expected: int, float, Decimal"
            )
        self.max_items = max_items
        self.item_type = item_type
        # Server-side element type name used in DDL and in to_vector().
        # NOTE(review): float -> "decimal" and Decimal -> "float" is kept
        # exactly as in the original; confirm against the server's accepted
        # VECTOR element type names that this mapping is intended.
        if item_type is float:
            self.item_type_server = "decimal"
        elif item_type is Decimal:
            self.item_type_server = "float"
        else:
            self.item_type_server = "int"

    def get_col_spec(self, **kw):
        """Return the DDL fragment for this column type."""
        # NOTE(review): __init__ never leaves item_type as None, so this
        # branch is only reachable if the attribute is reset afterwards.
        if self.max_items is None and self.item_type is None:
            return "VECTOR"
        # NOTE(review): when max_items is unset this yields
        # "VECTOR(<type>, )" -- output preserved from the original; verify
        # the server accepts it.
        length = str(self.max_items or "")
        return f"VECTOR({self.item_type_server}, {length})"

    def bind_processor(self, dialect):
        """Return a callable converting a list/tuple into ``"[a,b,c]"``."""

        def process(value):
            # None and empty sequences are passed through unchanged.
            if not value:
                return value
            if not isinstance(value, (list, tuple)):
                raise ValueError("expected list or tuple, got '%s'" % type(value))
            return f"[{','.join(str(v) for v in value)}]"

        return process

    def result_processor(self, dialect, coltype):
        """Return a callable converting a comma-separated string to a list."""

        def process(value):
            # None and empty strings are passed through unchanged.
            if not value:
                return value
            return [self.item_type(item) for item in value.split(",")]

        return process

    class comparator_factory(UserDefinedType.Comparator):
        """Vector comparison helpers available on ``IRISVector`` columns."""

        # An l2_distance helper (server function 'vector_l2') was left
        # disabled in the original and is intentionally not defined here.

        def max_inner_product(self, other):
            """Return a ``vector_dot_product(column, other)`` expression."""
            return self.func("vector_dot_product", other)

        def cosine_distance(self, other):
            """Return a ``vector_cosine(column, other)`` expression.

            NOTE(review): despite the name, this is the raw ``vector_cosine``
            value while ``cosine`` returns ``1 - vector_cosine``; the names
            look swapped but are kept for backward compatibility -- confirm.
            """
            return self.func("vector_cosine", other)

        def cosine(self, other):
            """Return a ``1 - vector_cosine(column, other)`` expression."""
            return 1 - self.func("vector_cosine", other)

        def func(self, funcname: str, other):
            """Build ``funcname(column, to_vector('[...]', <server type>))``.

            Args:
                funcname: Name of the SQL function to call.
                other: List or tuple of numbers to compare the column with.

            Raises:
                ValueError: If ``other`` is not a list or tuple.
            """
            if not isinstance(other, (list, tuple)):
                raise ValueError("expected list or tuple, got '%s'" % type(other))
            othervalue = f"[{','.join(str(v) for v in other)}]"
            # Inside this method body ``func`` resolves to the module-level
            # sqlalchemy ``func`` (class-scope names are not visible here),
            # not to this method.
            return getattr(func, funcname)(
                self, func.to_vector(othervalue, text(self.type.item_type_server))
            )
| 314 | + |
| 315 | + |
250 | 316 | class BIT(sqltypes.TypeEngine): |
251 | 317 | __visit_name__ = "BIT" |
252 | 318 |
|
|
0 commit comments