Skip to content

Commit 93fd272

Browse files
cleanup for numpy and base64 support
1 parent 226f22a commit 93fd272

File tree

4 files changed

+88
-84
lines changed

4 files changed

+88
-84
lines changed

elasticsearch/dsl/document_base.py

Lines changed: 71 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -356,73 +356,6 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
356356
value: Any = None
357357
required = None
358358
multi = None
359-
if name in annotations:
360-
# the field has a type annotation, so next we try to figure out
361-
# what field type we can use
362-
type_ = annotations[name]
363-
type_metadata = []
364-
if isinstance(type_, _AnnotatedAlias):
365-
type_metadata = type_.__metadata__
366-
type_ = type_.__origin__
367-
skip = False
368-
required = True
369-
multi = False
370-
while hasattr(type_, "__origin__"):
371-
if type_.__origin__ == ClassVar:
372-
skip = True
373-
break
374-
elif type_.__origin__ == Mapped:
375-
# M[type] -> extract the wrapped type
376-
type_ = type_.__args__[0]
377-
elif type_.__origin__ == Union:
378-
if len(type_.__args__) == 2 and type_.__args__[1] is type(None):
379-
# Optional[type] -> mark instance as optional
380-
required = False
381-
type_ = type_.__args__[0]
382-
else:
383-
raise TypeError("Unsupported union")
384-
elif type_.__origin__ in [list, List]:
385-
# List[type] -> mark instance as multi
386-
multi = True
387-
required = False
388-
type_ = type_.__args__[0]
389-
else:
390-
break
391-
if skip or type_ == ClassVar:
392-
# skip ClassVar attributes
393-
continue
394-
if type(type_) is UnionType:
395-
# a union given with the pipe syntax
396-
args = get_args(type_)
397-
if len(args) == 2 and args[1] is type(None):
398-
required = False
399-
type_ = type_.__args__[0]
400-
else:
401-
raise TypeError("Unsupported union")
402-
field = None
403-
field_args: List[Any] = []
404-
field_kwargs: Dict[str, Any] = {}
405-
if isinstance(type_, type) and issubclass(type_, InnerDoc):
406-
# object or nested field
407-
field = Nested if multi else Object
408-
field_args = [type_]
409-
elif type_ in self.type_annotation_map:
410-
# use best field type for the type hint provided
411-
field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment]
412-
413-
# if this field does not have a right-hand value, we look in the metadata
414-
# of the annotation to see if we find it there
415-
for md in type_metadata:
416-
if isinstance(md, (_FieldMetadataDict, Field)):
417-
attrs[name] = md
418-
419-
if field:
420-
field_kwargs = {
421-
"multi": multi,
422-
"required": required,
423-
**field_kwargs,
424-
}
425-
value = field(*field_args, **field_kwargs)
426359

427360
if name in attrs:
428361
# this field has a right-side value, which can be field
@@ -448,6 +381,77 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
448381
if multi is not None:
449382
value._multi = multi
450383

384+
if value is None and name in annotations:
385+
# the field has a type annotation, so next we try to figure out
386+
# what field type we can use
387+
type_ = annotations[name]
388+
type_metadata = []
389+
if isinstance(type_, _AnnotatedAlias):
390+
type_metadata = type_.__metadata__
391+
type_ = type_.__origin__
392+
skip = False
393+
required = True
394+
multi = False
395+
396+
# if this field does not have a right-hand value, we look in the metadata
397+
# of the annotation to see if we find it there
398+
for md in type_metadata:
399+
if isinstance(md, (_FieldMetadataDict, Field)):
400+
attrs[name] = md
401+
value = md
402+
403+
if value is None:
404+
while hasattr(type_, "__origin__"):
405+
if type_.__origin__ == ClassVar:
406+
skip = True
407+
break
408+
elif type_.__origin__ == Mapped:
409+
# M[type] -> extract the wrapped type
410+
type_ = type_.__args__[0]
411+
elif type_.__origin__ == Union:
412+
if len(type_.__args__) == 2 and type_.__args__[1] is type(None):
413+
# Optional[type] -> mark instance as optional
414+
required = False
415+
type_ = type_.__args__[0]
416+
else:
417+
raise TypeError("Unsupported union")
418+
elif type_.__origin__ in [list, List]:
419+
# List[type] -> mark instance as multi
420+
multi = True
421+
required = False
422+
type_ = type_.__args__[0]
423+
else:
424+
break
425+
if skip or type_ == ClassVar:
426+
# skip ClassVar attributes
427+
continue
428+
if type(type_) is UnionType:
429+
# a union given with the pipe syntax
430+
args = get_args(type_)
431+
if len(args) == 2 and args[1] is type(None):
432+
required = False
433+
type_ = type_.__args__[0]
434+
else:
435+
raise TypeError("Unsupported union")
436+
field = None
437+
field_args: List[Any] = []
438+
field_kwargs: Dict[str, Any] = {}
439+
if isinstance(type_, type) and issubclass(type_, InnerDoc):
440+
# object or nested field
441+
field = Nested if multi else Object
442+
field_args = [type_]
443+
elif type_ in self.type_annotation_map:
444+
# use best field type for the type hint provided
445+
field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment]
446+
447+
if field:
448+
field_kwargs = {
449+
"multi": multi,
450+
"required": required,
451+
**field_kwargs,
452+
}
453+
value = field(*field_args, **field_kwargs)
454+
451455
if value is None:
452456
raise TypeError(f"Cannot map field {name}")
453457

elasticsearch/dsl/field.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ def deserialize(self, data: Any) -> Any:
165165
def clean(self, data: Any) -> Any:
166166
if data is not None:
167167
data = self.deserialize(data)
168-
if data in (None, [], {}) and self._required:
168+
# the "data is ..." comparisons below work well when data is a numpy
169+
# array (only for dense vector fields)
170+
# unfortunately numpy overrides the == operator in a way that causes
171+
# errors when used instead of "is"
172+
if (data is None or data is [] or data is {}) and self._required:
169173
raise ValidationException("Value required for this field.")
170174
return data
171175

@@ -1616,13 +1620,6 @@ def __init__(
16161620
kwargs["multi"] = True
16171621
super().__init__(*args, **kwargs)
16181622

1619-
def _deserialize(self, data: Any) -> Any:
1620-
if self._element_type == "float":
1621-
return float(data)
1622-
elif self._element_type == "byte":
1623-
return int(data)
1624-
return data
1625-
16261623

16271624
class Double(Float):
16281625
"""

elasticsearch/dsl/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -609,11 +609,15 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
609609
if isinstance(v, AttrList):
610610
v = v._l_
611611

612-
# if skip_empty:
613-
# # don't serialize empty values
614-
# # careful not to include numeric zeros
615-
# if v in ([], {}, None):
616-
# continue
612+
if skip_empty:
613+
# don't serialize empty values
614+
# careful not to include numeric zeros
615+
# the "data is ..." comparisons below work well when data is a numpy
616+
# array (only for dense vector fields)
617+
# unfortunately numpy overrides the == operator in a way that causes
618+
# errors when used instead of "is"
619+
if v is None or v is [] or v is {}:
620+
continue
617621

618622
out[k] = v
619623
return out

examples/quotes/backend/quotes.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1414
from elasticsearch import dsl
1515
from elasticsearch.dsl.types import DenseVectorIndexOptions
16+
from elasticsearch.helpers.vectors import numpy_array_to_base64_dense_vector
1617

1718
model = SentenceTransformer("all-MiniLM-L6-v2")
1819
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer()
@@ -23,7 +24,7 @@ class Quote(AsyncBaseESModel):
2324
quote: str
2425
author: Annotated[str, dsl.Keyword()]
2526
tags: Annotated[list[str], dsl.Keyword()]
26-
embedding: Annotated[list[float], dsl.DenseVector(
27+
embedding: Annotated[list[float] | str, dsl.DenseVector(
2728
index_options=DenseVectorIndexOptions(type="flat"),
2829
)] = Field(init=False, default=[])
2930

@@ -141,9 +142,7 @@ def embed_quotes(quotes):
141142
embeddings = model.encode([q.quote for q in quotes])
142143
for q, e in zip(quotes, embeddings):
143144
q.embedding = e
144-
# q.embedding = e.tolist()
145-
##byte_array = e.byteswap().tobytes()
146-
##q.embedding = base64.b64encode(byte_array).decode()
145+
q.embedding = numpy_array_to_base64_dense_vector(e)
147146

148147

149148
async def ingest_quotes():

0 commit comments

Comments
 (0)