Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 73 additions & 67 deletions elasticsearch/dsl/document_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,73 +356,6 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
value: Any = None
required = None
multi = None
if name in annotations:
# the field has a type annotation, so next we try to figure out
# what field type we can use
type_ = annotations[name]
type_metadata = []
if isinstance(type_, _AnnotatedAlias):
type_metadata = type_.__metadata__
type_ = type_.__origin__
skip = False
required = True
multi = False
while hasattr(type_, "__origin__"):
if type_.__origin__ == ClassVar:
skip = True
break
elif type_.__origin__ == Mapped:
# M[type] -> extract the wrapped type
type_ = type_.__args__[0]
elif type_.__origin__ == Union:
if len(type_.__args__) == 2 and type_.__args__[1] is type(None):
# Optional[type] -> mark instance as optional
required = False
type_ = type_.__args__[0]
else:
raise TypeError("Unsupported union")
elif type_.__origin__ in [list, List]:
# List[type] -> mark instance as multi
multi = True
required = False
type_ = type_.__args__[0]
else:
break
if skip or type_ == ClassVar:
# skip ClassVar attributes
continue
if type(type_) is UnionType:
# a union given with the pipe syntax
args = get_args(type_)
if len(args) == 2 and args[1] is type(None):
required = False
type_ = type_.__args__[0]
else:
raise TypeError("Unsupported union")
field = None
field_args: List[Any] = []
field_kwargs: Dict[str, Any] = {}
if isinstance(type_, type) and issubclass(type_, InnerDoc):
# object or nested field
field = Nested if multi else Object
field_args = [type_]
elif type_ in self.type_annotation_map:
# use best field type for the type hint provided
field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment]

# if this field does not have a right-hand value, we look in the metadata
# of the annotation to see if we find it there
for md in type_metadata:
if isinstance(md, (_FieldMetadataDict, Field)):
attrs[name] = md

if field:
field_kwargs = {
"multi": multi,
"required": required,
**field_kwargs,
}
value = field(*field_args, **field_kwargs)

if name in attrs:
# this field has a right-side value, which can be field
Expand All @@ -448,6 +381,79 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
if multi is not None:
value._multi = multi

if value is None and name in annotations:
# the field has a type annotation, so next we try to figure out
# what field type we can use
type_ = annotations[name]
type_metadata = []
if isinstance(type_, _AnnotatedAlias):
type_metadata = type_.__metadata__
type_ = type_.__origin__
skip = False
required = True
multi = False

# if this field does not have a right-hand value, we look in the metadata
# of the annotation to see if we find it there
for md in type_metadata:
if isinstance(md, (_FieldMetadataDict, Field)):
attrs[name] = md
value = md

if value is None:
while hasattr(type_, "__origin__"):
if type_.__origin__ == ClassVar:
skip = True
break
elif type_.__origin__ == Mapped:
# M[type] -> extract the wrapped type
type_ = type_.__args__[0]
elif type_.__origin__ == Union:
if len(type_.__args__) == 2 and type_.__args__[1] is type(
None
):
# Optional[type] -> mark instance as optional
required = False
type_ = type_.__args__[0]
else:
raise TypeError("Unsupported union")
elif type_.__origin__ in [list, List]:
# List[type] -> mark instance as multi
multi = True
required = False
type_ = type_.__args__[0]
else:
break
if skip or type_ == ClassVar:
# skip ClassVar attributes
continue
if type(type_) is UnionType:
# a union given with the pipe syntax
args = get_args(type_)
if len(args) == 2 and args[1] is type(None):
required = False
type_ = type_.__args__[0]
else:
raise TypeError("Unsupported union")
field = None
field_args: List[Any] = []
field_kwargs: Dict[str, Any] = {}
if isinstance(type_, type) and issubclass(type_, InnerDoc):
# object or nested field
field = Nested if multi else Object
field_args = [type_]
elif type_ in self.type_annotation_map:
# use best field type for the type hint provided
field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment]

if field:
field_kwargs = {
"multi": multi,
"required": required,
**field_kwargs,
}
value = field(*field_args, **field_kwargs)

if value is None:
raise TypeError(f"Cannot map field {name}")

Expand Down
13 changes: 5 additions & 8 deletions elasticsearch/dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def deserialize(self, data: Any) -> Any:
def clean(self, data: Any) -> Any:
if data is not None:
data = self.deserialize(data)
if data in (None, [], {}) and self._required:
# the "data is ..." comparisons below work well when data is a numpy
# array (only for dense vector fields)
# unfortunately numpy overrides the == operator in a way that causes
# errors when used instead of "is"
if (data is None or data is [] or data is {}) and self._required:
raise ValidationException("Value required for this field.")
return data

Expand Down Expand Up @@ -1616,13 +1620,6 @@ def __init__(
kwargs["multi"] = True
super().__init__(*args, **kwargs)

def _deserialize(self, data: Any) -> Any:
    """Coerce a raw stored value to this field's element type.

    ``"float"`` elements are converted with :func:`float`, ``"byte"``
    elements with :func:`int`; any other element type passes the value
    through unchanged.
    """
    converters = {"float": float, "byte": int}
    convert = converters.get(self._element_type)
    return convert(data) if convert is not None else data


class Double(Float):
"""
Expand Down
6 changes: 5 additions & 1 deletion elasticsearch/dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,11 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
if skip_empty:
# don't serialize empty values
# careful not to include numeric zeros
if v in ([], {}, None):
# the "data is ..." comparisons below work well when data is a numpy
# array (only for dense vector fields)
# unfortunately numpy overrides the == operator in a way that causes
# errors when used instead of "is"
if v is None or v is [] or v is {}:
continue

out[k] = v
Expand Down
13 changes: 9 additions & 4 deletions examples/quotes/backend/quotes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
import csv
import os
from time import time
Expand All @@ -8,19 +9,22 @@
from pydantic import BaseModel, Field, ValidationError
from sentence_transformers import SentenceTransformer

from elasticsearch import NotFoundError
from elasticsearch import NotFoundError, OrjsonSerializer
from elasticsearch.dsl.pydantic import AsyncBaseESModel
from elasticsearch import dsl
from elasticsearch.dsl.types import DenseVectorIndexOptions
from elasticsearch.helpers.vectors import numpy_array_to_base64_dense_vector

model = SentenceTransformer("all-MiniLM-L6-v2")
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer()
)


class Quote(AsyncBaseESModel):
quote: str
author: Annotated[str, dsl.Keyword()]
tags: Annotated[list[str], dsl.Keyword()]
embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[])
embedding: Annotated[list[float] | str, dsl.DenseVector()] = Field(init=False, default=[])

class Index:
name = 'quotes'
Expand Down Expand Up @@ -135,7 +139,8 @@ async def search_quotes(req: SearchRequest) -> SearchResponse:
def embed_quotes(quotes):
    """Compute and attach an embedding to each quote, in place.

    All quote texts are encoded in a single batched ``model.encode`` call
    (much faster than per-quote encoding), and each resulting numpy vector
    is stored on its quote as a base64-encoded dense vector string.
    """
    embeddings = model.encode([q.quote for q in quotes])
    for q, e in zip(quotes, embeddings):
        # NOTE(review): the draft had a dead `q.embedding = e` assignment
        # immediately overwritten by the next line; only the encoded form
        # is kept.
        q.embedding = numpy_array_to_base64_dense_vector(e)


async def ingest_quotes():
Expand Down
13 changes: 5 additions & 8 deletions utils/templates/field.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ class Field(DslBase):
def clean(self, data: Any) -> Any:
if data is not None:
data = self.deserialize(data)
if data in (None, [], {}) and self._required:
# the "data is ..." comparisons below work well when data is a numpy
# array (only for dense vector fields)
# unfortunately numpy overrides the == operator in a way that causes
# errors when used instead of "is"
if (data is None or data is [] or data is {}) and self._required:
raise ValidationException("Value required for this field.")
return data

Expand Down Expand Up @@ -417,13 +421,6 @@ class {{ k.name }}({{ k.parent }}):
if self._element_type in ["float", "byte"]:
kwargs["multi"] = True
super().__init__(*args, **kwargs)

def _deserialize(self, data: Any) -> Any:
    # Coerce a raw stored value to this field's declared element type.
    if self._element_type == "float":
        return float(data)
    elif self._element_type == "byte":
        return int(data)
    # unknown/absent element type: pass the value through unchanged
    return data
{% elif k.field == "scaled_float" %}
if 'scaling_factor' not in kwargs:
if len(args) > 0:
Expand Down
Loading