Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/reference/async.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,27 @@ All APIs that are available under the sync client are also available under the a

See also the [Using OpenTelemetry](/reference/opentelemetry.md) page.

## Trio support

If you prefer using Trio instead of asyncio to take advantage of its better structured concurrency support, you can use the HTTPX async node which supports Trio out of the box.

```python
import trio
from elasticsearch import AsyncElasticsearch

client = AsyncElasticsearch(
"https://...",
api_key="...",
node_class="httpxasync")

async def main():
resp = await client.info()
print(resp.body)

trio.run(main)
```

The one limitation of Trio support is that it does not currently support node sniffing, which was not implemented with structured concurrency in mind.

## Frequently Asked Questions [_frequently_asked_questions]

Expand Down
16 changes: 14 additions & 2 deletions docs/reference/dsl_how_to_guides.md
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,12 @@ The DSL module supports async/await with [asyncio](https://docs.python.org/3/lib
$ python -m pip install "elasticsearch[async]"
```

The DSL module also supports [Trio](https://trio.readthedocs.io/en/stable/) when using the Async HTTPX client. You do need to install Trio and HTTPX separately:

```bash
$ python -m pip install "elasticsearch trio httpx"
```

### Connections [_connections]

Use the `async_connections` module to manage your asynchronous connections.
Expand All @@ -1565,6 +1571,14 @@ from elasticsearch.dsl import async_connections
async_connections.create_connection(hosts=['localhost'], timeout=20)
```

If you're using Trio, you need to explicitly request the Async HTTP client:

```python
from elasticsearch.dsl import async_connections

async_connections.create_connection(hosts=['localhost'], node_class="httpxasync")
```

All the options available in the `connections` module can be used with `async_connections`.

#### How to avoid *Unclosed client session / connector* warnings on exit [_how_to_avoid_unclosed_client_session_connector_warnings_on_exit]
Expand All @@ -1576,8 +1590,6 @@ es = async_connections.get_connection()
await es.close()
```



### Search DSL [_search_dsl]

Use the `AsyncSearch` class to perform asynchronous searches.
Expand Down
43 changes: 28 additions & 15 deletions elasticsearch/_async/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
Union,
)

from ..compat import safe_task
import sniffio
from anyio import create_memory_object_stream, create_task_group, move_on_after

from ..exceptions import ApiError, NotFoundError, TransportError
from ..helpers.actions import (
_TYPE_BULK_ACTION,
Expand All @@ -57,6 +59,15 @@
T = TypeVar("T")


async def _sleep(seconds: float) -> None:
if sniffio.current_async_library() == "trio":
import trio

await trio.sleep(seconds)
else:
await asyncio.sleep(seconds)


async def _chunk_actions(
actions: AsyncIterable[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY],
chunk_size: int,
Expand All @@ -82,32 +93,36 @@ async def _chunk_actions(
chunk_size=chunk_size, max_chunk_bytes=max_chunk_bytes, serializer=serializer
)

action: _TYPE_BULK_ACTION_WITH_META
data: _TYPE_BULK_ACTION_BODY
if not flush_after_seconds:
async for action, data in actions:
ret = chunker.feed(action, data)
if ret:
yield ret
else:
item_queue: asyncio.Queue[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY] = (
asyncio.Queue()
)
sender, receiver = create_memory_object_stream[
_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY
]()

async def get_items() -> None:
try:
async for item in actions:
await item_queue.put(item)
await sender.send(item)
finally:
await item_queue.put((BulkMeta.done, None))
await sender.send((BulkMeta.done, None))

async with create_task_group() as tg:
tg.start_soon(get_items)

async with safe_task(get_items()):
timeout: Optional[float] = flush_after_seconds
while True:
try:
action, data = await asyncio.wait_for(
item_queue.get(), timeout=timeout
)
action = {}
data = None
with move_on_after(timeout) as scope:
action, data = await receiver.receive()
timeout = flush_after_seconds
except asyncio.TimeoutError:
if scope.cancelled_caught:
action, data = BulkMeta.flush, None
timeout = None

Expand Down Expand Up @@ -294,9 +309,7 @@ async def map_actions() -> (
]
] = []
if attempt:
await asyncio.sleep(
min(max_backoff, initial_backoff * 2 ** (attempt - 1))
)
await _sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1)))

try:
data: Union[
Expand Down
17 changes: 2 additions & 15 deletions elasticsearch/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import inspect
import os
import sys
from contextlib import asynccontextmanager, contextmanager
from contextlib import contextmanager
from pathlib import Path
from threading import Thread
from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, Tuple, Type, Union
from typing import Any, Callable, Iterator, Tuple, Type, Union

string_types: Tuple[Type[str], Type[bytes]] = (str, bytes)

Expand Down Expand Up @@ -105,22 +104,10 @@ def run() -> None:
raise captured_exception


@asynccontextmanager
async def safe_task(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[asyncio.Task[Any]]:
"""Run a background task within a context manager block.

The task is awaited when the block ends.
"""
task = asyncio.create_task(coro)
yield task
await task


__all__ = [
"string_types",
"to_str",
"to_bytes",
"warn_stacklevel",
"safe_thread",
"safe_task",
]
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ keywords = [
]
dynamic = ["version"]
dependencies = [
"elastic-transport>=9.1.0,<10",
"elastic-transport>=9.2.0,<10",
"python-dateutil",
"typing-extensions",
"sniffio",
"anyio",
]

[project.optional-dependencies]
Expand All @@ -55,6 +57,7 @@ vectorstore_mmr = ["numpy>=1", "simsimd>=3"]
dev = [
"requests>=2, <3",
"aiohttp",
"httpx",
"pytest",
"pytest-cov",
"pytest-mock",
Expand All @@ -77,6 +80,7 @@ dev = [
"mapbox-vector-tile",
"jinja2",
"tqdm",
"trio",
"mypy",
"pyright",
"types-python-dateutil",
Expand Down
17 changes: 8 additions & 9 deletions test_elasticsearch/test_async/test_server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,26 @@
# under the License.

import pytest
import pytest_asyncio
import sniffio

import elasticsearch

from ...utils import CA_CERTS, wipe_cluster

pytestmark = pytest.mark.asyncio


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def async_client_factory(elasticsearch_url):

if not hasattr(elasticsearch, "AsyncElasticsearch"):
pytest.skip("test requires 'AsyncElasticsearch' and aiohttp to be installed")

kwargs = {}
if sniffio.current_async_library() == "trio":
kwargs["node_class"] = "httpxasync"
# Unfortunately the asyncio client needs to be rebuilt every
# test execution due to how pytest-asyncio manages
# event loops (one per test!)
client = None
try:
client = elasticsearch.AsyncElasticsearch(elasticsearch_url, ca_certs=CA_CERTS)
client = elasticsearch.AsyncElasticsearch(
elasticsearch_url, ca_certs=CA_CERTS, **kwargs
)
yield client
finally:
if client:
Expand Down
2 changes: 1 addition & 1 deletion test_elasticsearch/test_async/test_server/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest

pytestmark = pytest.mark.asyncio
pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("kwargs", [{"body": {"text": "привет"}}, {"text": "привет"}])
Expand Down
19 changes: 9 additions & 10 deletions test_elasticsearch/test_async/test_server/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,20 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import logging
import time
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, call, patch

import anyio
import pytest
import pytest_asyncio
from elastic_transport import ApiResponseMeta, ObjectApiResponse

from elasticsearch import helpers
from elasticsearch.exceptions import ApiError
from elasticsearch.helpers import ScanError

pytestmark = [pytest.mark.asyncio]
pytestmark = pytest.mark.anyio


class AsyncMock(MagicMock):
Expand Down Expand Up @@ -93,7 +92,7 @@ async def test_all_documents_get_inserted(self, async_client):
async def test_documents_data_types(self, async_client):
async def async_gen():
for x in range(100):
await asyncio.sleep(0)
await anyio.sleep(0)
yield {"answer": x, "_id": x}

def sync_gen():
Expand Down Expand Up @@ -129,7 +128,7 @@ async def async_gen():
yield {"answer": 2, "_id": 0}
yield {"answer": 1, "_id": 1}
yield helpers.BULK_FLUSH
await asyncio.sleep(0.5)
await anyio.sleep(0.5)
yield {"answer": 2, "_id": 2}

timestamps = []
Expand All @@ -146,7 +145,7 @@ async def test_timeout_flushes(self, async_client):
async def async_gen():
yield {"answer": 2, "_id": 0}
yield {"answer": 1, "_id": 1}
await asyncio.sleep(0.5)
await anyio.sleep(0.5)
yield {"answer": 2, "_id": 2}

timestamps = []
Expand Down Expand Up @@ -531,7 +530,7 @@ def __await__(self):
return self().__await__()


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def scan_teardown(async_client):
yield
await async_client.clear_scroll(scroll_id="_all")
Expand Down Expand Up @@ -955,7 +954,7 @@ async def test_scan_from_keyword_is_aliased(async_client, scan_kwargs):
assert "from" not in search_mock.call_args[1]


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def reindex_setup(async_client):
bulk = []
for x in range(100):
Expand Down Expand Up @@ -1033,7 +1032,7 @@ async def test_all_documents_get_moved(self, async_client, reindex_setup):
)["_source"]


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def parent_reindex_setup(async_client):
body = {
"settings": {"number_of_shards": 1, "number_of_replicas": 0},
Expand Down Expand Up @@ -1094,7 +1093,7 @@ async def test_children_are_reindexed_correctly(
} == q


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def reindex_data_stream_setup(async_client):
dt = datetime.now(tz=timezone.utc)
bulk = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
# under the License.

import pytest
import pytest_asyncio

from elasticsearch import RequestError

pytestmark = pytest.mark.asyncio
pytestmark = pytest.mark.anyio


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def mvt_setup(async_client):
await async_client.indices.create(
index="museums",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import warnings

import pytest
import pytest_asyncio

from elasticsearch import ElasticsearchWarning, RequestError

Expand All @@ -39,6 +38,8 @@
)
from ...utils import parse_version

# We're not using `pytest.mark.anyio` here because it would run the test suite twice,
# which does not work as it does not fully clean up after itself.
pytestmark = pytest.mark.asyncio

XPACK_FEATURES = None
Expand Down Expand Up @@ -240,7 +241,7 @@ async def _feature_enabled(self, name):
return name in XPACK_FEATURES


@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
def async_runner(async_client_factory):
return AsyncYamlRunner(async_client_factory)

Expand Down
Loading
Loading