diff --git a/elasticsearch/dsl/document_base.py b/elasticsearch/dsl/document_base.py index 626179747..152579e0b 100644 --- a/elasticsearch/dsl/document_base.py +++ b/elasticsearch/dsl/document_base.py @@ -356,73 +356,6 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]): value: Any = None required = None multi = None - if name in annotations: - # the field has a type annotation, so next we try to figure out - # what field type we can use - type_ = annotations[name] - type_metadata = [] - if isinstance(type_, _AnnotatedAlias): - type_metadata = type_.__metadata__ - type_ = type_.__origin__ - skip = False - required = True - multi = False - while hasattr(type_, "__origin__"): - if type_.__origin__ == ClassVar: - skip = True - break - elif type_.__origin__ == Mapped: - # M[type] -> extract the wrapped type - type_ = type_.__args__[0] - elif type_.__origin__ == Union: - if len(type_.__args__) == 2 and type_.__args__[1] is type(None): - # Optional[type] -> mark instance as optional - required = False - type_ = type_.__args__[0] - else: - raise TypeError("Unsupported union") - elif type_.__origin__ in [list, List]: - # List[type] -> mark instance as multi - multi = True - required = False - type_ = type_.__args__[0] - else: - break - if skip or type_ == ClassVar: - # skip ClassVar attributes - continue - if type(type_) is UnionType: - # a union given with the pipe syntax - args = get_args(type_) - if len(args) == 2 and args[1] is type(None): - required = False - type_ = type_.__args__[0] - else: - raise TypeError("Unsupported union") - field = None - field_args: List[Any] = [] - field_kwargs: Dict[str, Any] = {} - if isinstance(type_, type) and issubclass(type_, InnerDoc): - # object or nested field - field = Nested if multi else Object - field_args = [type_] - elif type_ in self.type_annotation_map: - # use best field type for the type hint provided - field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment] - - # if 
this field does not have a right-hand value, we look in the metadata - # of the annotation to see if we find it there - for md in type_metadata: - if isinstance(md, (_FieldMetadataDict, Field)): - attrs[name] = md - - if field: - field_kwargs = { - "multi": multi, - "required": required, - **field_kwargs, - } - value = field(*field_args, **field_kwargs) if name in attrs: # this field has a right-side value, which can be field @@ -448,6 +381,79 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]): if multi is not None: value._multi = multi + if value is None and name in annotations: + # the field has a type annotation, so next we try to figure out + # what field type we can use + type_ = annotations[name] + type_metadata = [] + if isinstance(type_, _AnnotatedAlias): + type_metadata = type_.__metadata__ + type_ = type_.__origin__ + skip = False + required = True + multi = False + + # if this field does not have a right-hand value, we look in the metadata + # of the annotation to see if we find it there + for md in type_metadata: + if isinstance(md, (_FieldMetadataDict, Field)): + attrs[name] = md + value = md + + if value is None: + while hasattr(type_, "__origin__"): + if type_.__origin__ == ClassVar: + skip = True + break + elif type_.__origin__ == Mapped: + # M[type] -> extract the wrapped type + type_ = type_.__args__[0] + elif type_.__origin__ == Union: + if len(type_.__args__) == 2 and type_.__args__[1] is type( + None + ): + # Optional[type] -> mark instance as optional + required = False + type_ = type_.__args__[0] + else: + raise TypeError("Unsupported union") + elif type_.__origin__ in [list, List]: + # List[type] -> mark instance as multi + multi = True + required = False + type_ = type_.__args__[0] + else: + break + if skip or type_ == ClassVar: + # skip ClassVar attributes + continue + if type(type_) is UnionType: + # a union given with the pipe syntax + args = get_args(type_) + if len(args) == 2 and args[1] is type(None): + 
required = False + type_ = type_.__args__[0] + else: + raise TypeError("Unsupported union") + field = None + field_args: List[Any] = [] + field_kwargs: Dict[str, Any] = {} + if isinstance(type_, type) and issubclass(type_, InnerDoc): + # object or nested field + field = Nested if multi else Object + field_args = [type_] + elif type_ in self.type_annotation_map: + # use best field type for the type hint provided + field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment] + + if field: + field_kwargs = { + "multi": multi, + "required": required, + **field_kwargs, + } + value = field(*field_args, **field_kwargs) + if value is None: raise TypeError(f"Cannot map field {name}") diff --git a/elasticsearch/dsl/field.py b/elasticsearch/dsl/field.py index 3b5075287..8ba88d22d 100644 --- a/elasticsearch/dsl/field.py +++ b/elasticsearch/dsl/field.py @@ -165,7 +165,11 @@ def deserialize(self, data: Any) -> Any: def clean(self, data: Any) -> Any: if data is not None: data = self.deserialize(data) - if data in (None, [], {}) and self._required: + # the "data is ..." 
comparison below, combined with an isinstance/truthiness check, + # replaces the old "data in (None, [], {})" test: numpy arrays (dense + # vector fields) override "==" elementwise, so "in"/"==" would raise, + # while this form keeps treating empty list/dict as missing values + if (data is None or (isinstance(data, (list, dict)) and not data)) and self._required: raise ValidationException("Value required for this field.") return data @@ -1616,13 +1620,6 @@ def __init__( kwargs["multi"] = True super().__init__(*args, **kwargs) - def _deserialize(self, data: Any) -> Any: - if self._element_type == "float": - return float(data) - elif self._element_type == "byte": - return int(data) - return data - class Double(Float): """ diff --git a/elasticsearch/dsl/utils.py b/elasticsearch/dsl/utils.py index cce3c052c..9c9bda12e 100644 --- a/elasticsearch/dsl/utils.py +++ b/elasticsearch/dsl/utils.py @@ -612,7 +612,11 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]: if skip_empty: # don't serialize empty values # careful not to include numeric zeros - if v in ([], {}, None): + # the "data is ..." 
comparison below, combined with an isinstance/truthiness check, + # replaces the old "v in ([], {}, None)" test: numpy arrays (dense + # vector fields) override "==" elementwise, so "in"/"==" would raise, + # while this form keeps skipping empty list/dict values + if v is None or (isinstance(v, (list, dict)) and not v): continue out[k] = v diff --git a/examples/quotes/backend/quotes.py b/examples/quotes/backend/quotes.py index 4492d5e7e..c113b9239 100644 --- a/examples/quotes/backend/quotes.py +++ b/examples/quotes/backend/quotes.py @@ -1,4 +1,5 @@ import asyncio +import base64 import csv import os from time import time @@ -8,19 +9,22 @@ from pydantic import BaseModel, Field, ValidationError from sentence_transformers import SentenceTransformer -from elasticsearch import NotFoundError +from elasticsearch import NotFoundError, OrjsonSerializer from elasticsearch.dsl.pydantic import AsyncBaseESModel from elasticsearch import dsl +from elasticsearch.dsl.types import DenseVectorIndexOptions +from elasticsearch.helpers.vectors import numpy_array_to_base64_dense_vector model = SentenceTransformer("all-MiniLM-L6-v2") -dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']]) +dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer() +) class Quote(AsyncBaseESModel): quote: str author: Annotated[str, dsl.Keyword()] tags: Annotated[list[str], dsl.Keyword()] - embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[]) + embedding: Annotated[list[float] | str, dsl.DenseVector()] = Field(init=False, default=[]) class Index: name = 'quotes' @@ -135,7 +139,7 @@ async def search_quotes(req: SearchRequest) -> SearchResponse: def embed_quotes(quotes): embeddings = model.encode([q.quote for q in quotes]) for q, e in zip(quotes, embeddings): - q.embedding = e.tolist() + q.embedding = numpy_array_to_base64_dense_vector(e) async def ingest_quotes(): diff --git a/utils/templates/field.py.tpl b/utils/templates/field.py.tpl index 43df1b5f0..c965e484a 
100644 --- a/utils/templates/field.py.tpl +++ b/utils/templates/field.py.tpl @@ -159,7 +159,11 @@ class Field(DslBase): def clean(self, data: Any) -> Any: if data is not None: data = self.deserialize(data) - if data in (None, [], {}) and self._required: + # the "data is None" comparison below, combined with an isinstance/ + # truthiness check, replaces "data in (None, [], {})": numpy arrays + # (dense vector fields) override "==" elementwise, so "in"/"==" would + # raise, while this form keeps treating empty list/dict as missing + if (data is None or (isinstance(data, (list, dict)) and not data)) and self._required: raise ValidationException("Value required for this field.") return data @@ -417,13 +421,6 @@ class {{ k.name }}({{ k.parent }}): if self._element_type in ["float", "byte"]: kwargs["multi"] = True super().__init__(*args, **kwargs) - - def _deserialize(self, data: Any) -> Any: - if self._element_type == "float": - return float(data) - elif self._element_type == "byte": - return int(data) - return data {% elif k.field == "scaled_float" %} if 'scaling_factor' not in kwargs: if len(args) > 0: