Skip to content

Commit 8c38d16

Browse files
more minor adjustments to response types
1 parent 099a52f commit 8c38d16

File tree

9 files changed

+31
-15
lines changed

9 files changed

+31
-15
lines changed

elasticsearch_dsl/response/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,14 @@
4040
from ..search_base import Request, SearchBase
4141
from ..update_by_query_base import UpdateByQueryBase
4242

43-
__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta"]
43+
__all__ = [
44+
"Response",
45+
"AggResponse",
46+
"UpdateByQueryResponse",
47+
"Hit",
48+
"HitMeta",
49+
"AggregateResponseType",
50+
]
4451

4552

4653
class Response(AttrDict[Any], Generic[_R]):

elasticsearch_dsl/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from elastic_transport.client_utils import DEFAULT, DefaultType
2121

22-
from elasticsearch_dsl import Query, function
22+
from elasticsearch_dsl import Query, function, index_base
2323
from elasticsearch_dsl.document_base import InstrumentedField
2424
from elasticsearch_dsl.utils import AttrDict
2525

@@ -5100,7 +5100,7 @@ class Hit(AttrDict[Any]):
51005100
:arg sort:
51015101
"""
51025102

5103-
index: str
5103+
index: index_base.IndexBase
51045104
id: str
51055105
score: Union[float, None]
51065106
explanation: "Explanation"

examples/async/composite_agg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
import asyncio
1919
import os
20-
from typing import Any, AsyncIterator, Dict, Mapping, Sequence
20+
from typing import Any, AsyncIterator, Dict, Mapping, Sequence, cast
2121

2222
from elasticsearch.helpers import async_bulk
2323

2424
from elasticsearch_dsl import Agg, AsyncSearch, Response, aggs, async_connections
25+
from elasticsearch_dsl.types import CompositeAggregate
2526
from tests.test_integration.test_data import DATA, GIT_INDEX
2627

2728

@@ -30,7 +31,7 @@ async def scan_aggs(
3031
source_aggs: Sequence[Mapping[str, Agg]],
3132
inner_aggs: Dict[str, Agg] = {},
3233
size: int = 10,
33-
) -> AsyncIterator[Any]:
34+
) -> AsyncIterator[CompositeAggregate]:
3435
"""
3536
Helper function used to iterate over all possible bucket combinations of
3637
``source_aggs``, returning results of ``inner_aggs`` for each. Uses the
@@ -54,7 +55,7 @@ async def run_search(**kwargs: Any) -> Response:
5455
response = await run_search()
5556
while response.aggregations["comp"].buckets:
5657
for b in response.aggregations["comp"].buckets:
57-
yield b
58+
yield cast(CompositeAggregate, b)
5859
if "after_key" in response.aggregations["comp"]:
5960
after = response.aggregations["comp"].after_key
6061
else:

examples/async/parent_child.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ async def add_answer(
162162
# required make sure the answer is stored in the same shard
163163
_routing=self.meta.id,
164164
# since we don't have explicit index, ensure same index as self
165-
_index=self.meta.index,
165+
_index=cast(AsyncIndex, self.meta.index),
166166
# set up the parent/child mapping
167167
question_answer={"name": "answer", "parent": self.meta.id},
168168
# pass in the field values
@@ -218,7 +218,7 @@ async def get_question(self) -> Optional[Question]:
218218
# any attributes set on self would be interpreted as fields
219219
if "question" not in self.meta:
220220
self.meta.question = await Question.get(
221-
id=self.question_answer.parent, index=self.meta.index
221+
id=self.question_answer.parent, index=self.meta.index._name
222222
)
223223
return cast(Optional[Question], self.meta.question)
224224

examples/composite_agg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
# under the License.
1717

1818
import os
19-
from typing import Any, Dict, Iterator, Mapping, Sequence
19+
from typing import Any, Dict, Iterator, Mapping, Sequence, cast
2020

2121
from elasticsearch.helpers import bulk
2222

2323
from elasticsearch_dsl import Agg, Response, Search, aggs, connections
24+
from elasticsearch_dsl.types import CompositeAggregate
2425
from tests.test_integration.test_data import DATA, GIT_INDEX
2526

2627

@@ -29,7 +30,7 @@ def scan_aggs(
2930
source_aggs: Sequence[Mapping[str, Agg]],
3031
inner_aggs: Dict[str, Agg] = {},
3132
size: int = 10,
32-
) -> Iterator[Any]:
33+
) -> Iterator[CompositeAggregate]:
3334
"""
3435
Helper function used to iterate over all possible bucket combinations of
3536
``source_aggs``, returning results of ``inner_aggs`` for each. Uses the
@@ -53,7 +54,7 @@ def run_search(**kwargs: Any) -> Response:
5354
response = run_search()
5455
while response.aggregations["comp"].buckets:
5556
for b in response.aggregations["comp"].buckets:
56-
yield b
57+
yield cast(CompositeAggregate, b)
5758
if "after_key" in response.aggregations["comp"]:
5859
after = response.aggregations["comp"].after_key
5960
else:

examples/parent_child.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def add_answer(
161161
# required make sure the answer is stored in the same shard
162162
_routing=self.meta.id,
163163
# since we don't have explicit index, ensure same index as self
164-
_index=self.meta.index,
164+
_index=cast(Index, self.meta.index),
165165
# set up the parent/child mapping
166166
question_answer={"name": "answer", "parent": self.meta.id},
167167
# pass in the field values
@@ -217,7 +217,7 @@ def get_question(self) -> Optional[Question]:
217217
# any attributes set on self would be interpreted as fields
218218
if "question" not in self.meta:
219219
self.meta.question = Question.get(
220-
id=self.question_answer.parent, index=self.meta.index
220+
id=self.question_answer.parent, index=self.meta.index._name
221221
)
222222
return cast(Optional[Question], self.meta.question)
223223

utils/generator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,13 @@ def interface_to_python_class(
742742
k, arg, for_types_py=for_types_py, for_response=for_response
743743
)
744744

745+
if interface == "Hit" and arg["name"] == "index":
746+
# Python DSL replaces the string typed index attribute
747+
# with an Index or AsyncIndex instance. Here we use
748+
# IndexBase, which is a base class for both Index and
749+
# AsyncIndex.
750+
k["args"][-1]["type"] = "index_base.IndexBase"
751+
745752
if "inherits" not in type_ or "type" not in type_["inherits"]:
746753
break
747754

utils/templates/response.__init__.py.tpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
4040
from ..update_by_query_base import UpdateByQueryBase
4141
from .. import types
4242

43-
__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta"]
43+
__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta", "AggregateResponseType"]
4444

4545

4646
class Response(AttrDict[Any], Generic[_R]):

utils/templates/types.py.tpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union
2020
from elastic_transport.client_utils import DEFAULT, DefaultType
2121

2222
from elasticsearch_dsl.document_base import InstrumentedField
23-
from elasticsearch_dsl import function, Query
23+
from elasticsearch_dsl import function, index_base, Query
2424
from elasticsearch_dsl.utils import AttrDict
2525

2626
PipeSeparatedFlags = str

0 commit comments

Comments
 (0)