Skip to content

Commit 4dd03f8

Browse files
Support Annotated typing hint
1 parent 3d47b77 commit 4dd03f8

File tree

3 files changed

+269
-207
lines changed

3 files changed

+269
-207
lines changed

elasticsearch/dsl/document_base.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
get_args,
3434
overload,
3535
)
36+
from typing_extensions import _AnnotatedAlias
3637

3738
try:
3839
from types import UnionType
@@ -343,6 +344,10 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
343344
# the field has a type annotation, so next we try to figure out
344345
# what field type we can use
345346
type_ = annotations[name]
347+
type_metadata = []
348+
if isinstance(type_, _AnnotatedAlias):
349+
type_metadata = type_.__metadata__
350+
type_ = type_.__origin__
346351
skip = False
347352
required = True
348353
multi = False
@@ -389,6 +394,13 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
389394
# use best field type for the type hint provided
390395
field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment]
391396

397+
if name not in attrs:
398+
# if this field does not have a right-hand value, we look in the metadata
399+
# of the annotation to see if we find it there
400+
for md in type_metadata:
401+
if isinstance(md, (_FieldMetadataDict, Field)):
402+
attrs[name] = md
403+
392404
if field:
393405
field_kwargs = {
394406
"multi": multi,
@@ -401,7 +413,7 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
401413
# this field has a right-side value, which can be field
402414
# instance on its own or wrapped with mapped_field()
403415
attr_value = attrs[name]
404-
if isinstance(attr_value, dict):
416+
if isinstance(attr_value, _FieldMetadataDict):
405417
# the mapped_field() wrapper function was used so we need
406418
# to look for the field instance and also record any
407419
# dataclass-style defaults
@@ -490,6 +502,12 @@ def __delete__(self, instance: Any) -> None: ...
490502
M = Mapped
491503

492504

505+
class _FieldMetadataDict(dict):
506+
"""This class is used to identify metadata returned by the `mapped_field()` function."""
507+
508+
pass
509+
510+
493511
def mapped_field(
494512
field: Optional[Field] = None,
495513
*,
@@ -514,13 +532,13 @@ def mapped_field(
514532
when one isn't provided explicitly. Only one of ``factory`` and
515533
``default_factory`` can be used.
516534
"""
517-
return {
518-
"_field": field,
519-
"init": init,
520-
"default": default,
521-
"default_factory": default_factory,
535+
return _FieldMetadataDict(
536+
_field=field,
537+
init=init,
538+
default=default,
539+
default_factory=default_factory,
522540
**kwargs,
523-
}
541+
)
524542

525543

526544
@dataclass_transform(field_specifiers=(mapped_field,))

test_elasticsearch/test_dsl/_async/test_document.py

Lines changed: 122 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import sys
2828
from datetime import datetime
2929
from hashlib import md5
30-
from typing import Any, ClassVar, Dict, List, Optional
30+
from typing import Annotated, Any, ClassVar, Dict, List, Optional
3131

3232
import pytest
3333
from pytest import raises
@@ -530,7 +530,7 @@ def test_document_inheritance() -> None:
530530
} == MySubDoc._doc_type.mapping.to_dict()
531531

532532

533-
def test_child_class_can_override_parent() -> None:
533+
def test_childdoc_class_can_override_parent() -> None:
534534
class A(AsyncDocument):
535535
o = field.Object(dynamic=False, properties={"a": field.Text()})
536536

@@ -679,117 +679,139 @@ class TypedDoc(AsyncDocument):
679679
i1: ClassVar
680680
i2: ClassVar[int]
681681

682-
props = TypedDoc._doc_type.mapping.to_dict()["properties"]
683-
assert props == {
684-
"st": {"type": "text"},
685-
"dt": {"type": "date"},
686-
"li": {"type": "integer"},
687-
"ob": {
688-
"type": "object",
689-
"properties": {
690-
"st": {"type": "text"},
691-
"dt": {"type": "date"},
692-
"li": {"type": "integer"},
682+
class TypedDocAnnotated(AsyncDocument):
683+
st: Annotated[str, "foo"]
684+
dt: Annotated[Optional[datetime], "bar"]
685+
li: Annotated[List[int], "baz"]
686+
ob: Annotated[TypedInnerDoc, "qux"]
687+
ns: Annotated[List[TypedInnerDoc], "quux"]
688+
ip: Annotated[Optional[str], field.Ip()]
689+
k1: Annotated[str, field.Keyword(required=True)]
690+
k2: Annotated[M[str], field.Keyword()]
691+
k3: Annotated[str, mapped_field(field.Keyword(), default="foo")]
692+
k4: Annotated[M[Optional[str]], mapped_field(field.Keyword())] # type: ignore[misc]
693+
s1: Annotated[Secret, SecretField()]
694+
s2: Annotated[M[Secret], SecretField()]
695+
s3: Annotated[Secret, mapped_field(SecretField())] # type: ignore[misc]
696+
s4: Annotated[
697+
M[Optional[Secret]],
698+
mapped_field(SecretField(), default_factory=lambda: "foo"),
699+
]
700+
i1: Annotated[ClassVar, "classvar"]
701+
i2: Annotated[ClassVar[int], "classvar"]
702+
703+
for doc_class in [TypedDoc, TypedDocAnnotated]:
704+
props = doc_class._doc_type.mapping.to_dict()["properties"]
705+
assert props == {
706+
"st": {"type": "text"},
707+
"dt": {"type": "date"},
708+
"li": {"type": "integer"},
709+
"ob": {
710+
"type": "object",
711+
"properties": {
712+
"st": {"type": "text"},
713+
"dt": {"type": "date"},
714+
"li": {"type": "integer"},
715+
},
693716
},
694-
},
695-
"ns": {
696-
"type": "nested",
697-
"properties": {
698-
"st": {"type": "text"},
699-
"dt": {"type": "date"},
700-
"li": {"type": "integer"},
717+
"ns": {
718+
"type": "nested",
719+
"properties": {
720+
"st": {"type": "text"},
721+
"dt": {"type": "date"},
722+
"li": {"type": "integer"},
723+
},
701724
},
702-
},
703-
"ip": {"type": "ip"},
704-
"k1": {"type": "keyword"},
705-
"k2": {"type": "keyword"},
706-
"k3": {"type": "keyword"},
707-
"k4": {"type": "keyword"},
708-
"s1": {"type": "text"},
709-
"s2": {"type": "text"},
710-
"s3": {"type": "text"},
711-
"s4": {"type": "text"},
712-
}
725+
"ip": {"type": "ip"},
726+
"k1": {"type": "keyword"},
727+
"k2": {"type": "keyword"},
728+
"k3": {"type": "keyword"},
729+
"k4": {"type": "keyword"},
730+
"s1": {"type": "text"},
731+
"s2": {"type": "text"},
732+
"s3": {"type": "text"},
733+
"s4": {"type": "text"},
734+
}
713735

714-
TypedDoc.i1 = "foo"
715-
TypedDoc.i2 = 123
736+
doc_class.i1 = "foo"
737+
doc_class.i2 = 123
738+
739+
doc = doc_class()
740+
assert doc.k3 == "foo"
741+
assert doc.s4 == "foo"
742+
with raises(ValidationException) as exc_info:
743+
doc.full_clean()
744+
assert set(exc_info.value.args[0].keys()) == {
745+
"st",
746+
"k1",
747+
"k2",
748+
"ob",
749+
"s1",
750+
"s2",
751+
"s3",
752+
}
716753

717-
doc = TypedDoc()
718-
assert doc.k3 == "foo"
719-
assert doc.s4 == "foo"
720-
with raises(ValidationException) as exc_info:
754+
assert doc_class.i1 == "foo"
755+
assert doc_class.i2 == 123
756+
757+
doc.st = "s"
758+
doc.li = [1, 2, 3]
759+
doc.k1 = "k1"
760+
doc.k2 = "k2"
761+
doc.ob.st = "s"
762+
doc.ob.li = [1]
763+
doc.s1 = "s1"
764+
doc.s2 = "s2"
765+
doc.s3 = "s3"
721766
doc.full_clean()
722-
assert set(exc_info.value.args[0].keys()) == {
723-
"st",
724-
"k1",
725-
"k2",
726-
"ob",
727-
"s1",
728-
"s2",
729-
"s3",
730-
}
731767

732-
assert TypedDoc.i1 == "foo"
733-
assert TypedDoc.i2 == 123
734-
735-
doc.st = "s"
736-
doc.li = [1, 2, 3]
737-
doc.k1 = "k1"
738-
doc.k2 = "k2"
739-
doc.ob.st = "s"
740-
doc.ob.li = [1]
741-
doc.s1 = "s1"
742-
doc.s2 = "s2"
743-
doc.s3 = "s3"
744-
doc.full_clean()
768+
doc.ob = TypedInnerDoc(li=[1])
769+
with raises(ValidationException) as exc_info:
770+
doc.full_clean()
771+
assert set(exc_info.value.args[0].keys()) == {"ob"}
772+
assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st"}
745773

746-
doc.ob = TypedInnerDoc(li=[1])
747-
with raises(ValidationException) as exc_info:
748-
doc.full_clean()
749-
assert set(exc_info.value.args[0].keys()) == {"ob"}
750-
assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st"}
774+
doc.ob.st = "s"
775+
doc.ns.append(TypedInnerDoc(li=[1, 2]))
776+
with raises(ValidationException) as exc_info:
777+
doc.full_clean()
751778

752-
doc.ob.st = "s"
753-
doc.ns.append(TypedInnerDoc(li=[1, 2]))
754-
with raises(ValidationException) as exc_info:
779+
doc.ns[0].st = "s"
755780
doc.full_clean()
756781

757-
doc.ns[0].st = "s"
758-
doc.full_clean()
759-
760-
doc.ip = "1.2.3.4"
761-
n = datetime.now()
762-
doc.dt = n
763-
assert doc.to_dict() == {
764-
"st": "s",
765-
"li": [1, 2, 3],
766-
"dt": n,
767-
"ob": {
782+
doc.ip = "1.2.3.4"
783+
n = datetime.now()
784+
doc.dt = n
785+
assert doc.to_dict() == {
768786
"st": "s",
769-
"li": [1],
770-
},
771-
"ns": [
772-
{
787+
"li": [1, 2, 3],
788+
"dt": n,
789+
"ob": {
773790
"st": "s",
774-
"li": [1, 2],
775-
}
776-
],
777-
"ip": "1.2.3.4",
778-
"k1": "k1",
779-
"k2": "k2",
780-
"k3": "foo",
781-
"s1": "s1",
782-
"s2": "s2",
783-
"s3": "s3",
784-
"s4": "foo",
785-
}
791+
"li": [1],
792+
},
793+
"ns": [
794+
{
795+
"st": "s",
796+
"li": [1, 2],
797+
}
798+
],
799+
"ip": "1.2.3.4",
800+
"k1": "k1",
801+
"k2": "k2",
802+
"k3": "foo",
803+
"s1": "s1",
804+
"s2": "s2",
805+
"s3": "s3",
806+
"s4": "foo",
807+
}
786808

787-
s = TypedDoc.search().sort(TypedDoc.st, -TypedDoc.dt, +TypedDoc.ob.st)
788-
s.aggs.bucket("terms_agg", "terms", field=TypedDoc.k1)
789-
assert s.to_dict() == {
790-
"aggs": {"terms_agg": {"terms": {"field": "k1"}}},
791-
"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"],
792-
}
809+
s = doc_class.search().sort(doc_class.st, -doc_class.dt, +doc_class.ob.st)
810+
s.aggs.bucket("terms_agg", "terms", field=doc_class.k1)
811+
assert s.to_dict() == {
812+
"aggs": {"terms_agg": {"terms": {"field": "k1"}}},
813+
"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"],
814+
}
793815

794816

795817
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10")

0 commit comments

Comments
 (0)