Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions elastic_transport/_async_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import logging
import time
from typing import (
Any,
Awaitable,
Expand All @@ -30,6 +31,8 @@
Union,
)

import sniffio

from ._compat import await_if_coro
from ._exceptions import (
ConnectionError,
Expand Down Expand Up @@ -169,6 +172,7 @@ def __init__(
# time it's needed. Gets set within '_async_call()' which should
# precede all logic within async calls.
self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]
self._async_library: str = None # type: ignore[assignment]

# AsyncTransport doesn't require a thread lock for
# sniffing. Uses '_sniffing_task' instead.
Expand Down Expand Up @@ -258,7 +262,7 @@ async def perform_request( # type: ignore[override, return]
node_failure = False
last_response: Optional[TransportApiResponse] = None
node: BaseAsyncNode = self.node_pool.get() # type: ignore[assignment]
start_time = self._loop.time()
start_time = time.monotonic()
try:
otel_span.set_node_metadata(
node.host, node.port, node.base_url, target, method
Expand All @@ -277,7 +281,7 @@ async def perform_request( # type: ignore[override, return]
node.base_url,
target,
resp.meta.status,
self._loop.time() - start_time,
time.monotonic() - start_time,
)
)

Expand All @@ -300,7 +304,7 @@ async def perform_request( # type: ignore[override, return]
node.base_url,
target,
"N/A",
self._loop.time() - start_time,
time.monotonic() - start_time,
)
)

Expand Down Expand Up @@ -377,6 +381,10 @@ async def perform_request( # type: ignore[override, return]
)

async def sniff(self, is_initial_sniff: bool = False) -> None: # type: ignore[override]
if sniffio.current_async_library() != "asyncio":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be more appropriate here to show the error when current_async_library == "trio", since we know for a fact that sniffing does not work with it? We don't know if sniffing works or not with other async libraries, I assume. And also, this function raises an exception when it cannot recognize which library is in use. For backwards compatibility maybe we should allow the sniffing in such a case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other async library?

  • I haven't seen https://github.com/dabeaz/curio used anywhere, and it isn't making new releases anyway.
  • Twisted isn't supported by sniffio or anyio
  • trio-asyncio never really took off, and switching to current_async_library == trio won't make any difference

I wouldn't see a different library making the same mistakes as asyncio, ie. allowing background tasks. Anyways, Trio has shown how hard it is for a different async library to take off, even with superior design, excellent docs, etc. So, for me, the condition makes sense: today, we only support sniffing with asyncio.

I find it interesting how this is exactly the discussion we're having in elastic/elasticsearch-py#3103 (comment).

raise ValueError(
f"Asynchronous sniffing only works with the 'asyncio' library, got {sniffio.current_async_library}"
)
await self._async_call()
task = self._create_sniffing_task(is_initial_sniff)

Expand Down Expand Up @@ -409,8 +417,7 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
self._sniffing_task.result()

return (
self._loop.time() - self._last_sniffed_at
>= self._min_delay_between_sniffing
time.monotonic() - self._last_sniffed_at >= self._min_delay_between_sniffing
)

def _create_sniffing_task(
Expand All @@ -429,7 +436,7 @@ async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
"""Implementation of the sniffing task"""
previously_sniffed_at = self._last_sniffed_at
try:
self._last_sniffed_at = self._loop.time()
self._last_sniffed_at = time.monotonic()
options = SniffOptions(
is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
)
Expand Down Expand Up @@ -466,8 +473,13 @@ async def _async_call(self) -> None:
because we're not guaranteed to be within an active asyncio event loop
when __init__() is called.
"""
if self._loop is not None:
if self._async_library is not None:
return # Call at most once!

self._async_library = sniffio.current_async_library()
if self._async_library != "asyncio":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above here. This is unlikely, but let's say that someone using "curio" was able to make sniffing work before. This would prevent them from using this feature. I think the check should be for == "trio", which is the one case that we know fails.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, curio allows daemon tasks: https://curio.readthedocs.io/en/latest/reference.html#spawn. But I'm not aware of any HTTP client that supports curio. Since AnyIO dropped support for Curio, even asks does not support it. HTTPX does not support it.

return

self._loop = asyncio.get_running_loop()
if self._sniff_on_start:
await self.sniff(True)
5 changes: 4 additions & 1 deletion elastic_transport/_node/_http_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ class RequestKwarg(TypedDict, total=False):


class AiohttpHttpNode(BaseAsyncNode):
"""Default asynchronous node class using the ``aiohttp`` library via HTTP"""
"""Default asynchronous node class using the ``aiohttp`` library via HTTP.

Supports asyncio.
"""

_CLIENT_META_HTTP_CLIENT = ("ai", _AIOHTTP_META_VERSION)

Expand Down
8 changes: 6 additions & 2 deletions elastic_transport/_node/_http_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@


class HttpxAsyncHttpNode(BaseAsyncNode):
"""
Async HTTP node using httpx. Supports both Trio and asyncio.
"""

_CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION)

def __init__(self, config: NodeConfig):
Expand Down Expand Up @@ -175,11 +179,11 @@ async def perform_request( # type: ignore[override]
body=body,
exception=err,
)
raise err from None
raise err from e

meta = ApiResponseMeta(
resp.status_code,
resp.http_version,
resp.http_version.lstrip("HTTP/"),
HttpHeaders(resp.headers),
duration,
self.config,
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
install_requires=[
"urllib3>=1.26.2, <3",
"certifi",
"sniffio",
],
python_requires=">=3.8",
extras_require={
Expand All @@ -70,6 +71,8 @@
"opentelemetry-api",
"opentelemetry-sdk",
"orjson",
"anyio",
"trio",
# Override Read the Docs default (sphinx<2)
"sphinx>2",
"furo",
Expand Down
77 changes: 47 additions & 30 deletions tests/async_/test_async_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import warnings
from unittest import mock

import anyio
import pytest
import sniffio

from elastic_transport import (
AiohttpHttpNode,
Expand All @@ -45,9 +47,11 @@
from tests.conftest import AsyncDummyNode


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_async_transport_httpbin(httpbin_node_config, httpbin):
t = AsyncTransport([httpbin_node_config], meta_header=False)
t = AsyncTransport(
[httpbin_node_config], meta_header=False, node_class=HttpxAsyncHttpNode
)
resp, data = await t.perform_request("GET", "/anything?key=value")

assert resp.status == 200
Expand All @@ -57,6 +61,8 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):

data["headers"].pop("X-Amzn-Trace-Id", None)
assert data["headers"] == {
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
"User-Agent": DEFAULT_USER_AGENT,
"Connection": "keep-alive",
"Host": f"{httpbin.host}:{httpbin.port}",
Expand All @@ -66,15 +72,15 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8"
)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_transport_close_node_pool():
t = AsyncTransport([NodeConfig("http", "localhost", 443)])
with mock.patch.object(t.node_pool.all()[0], "close") as node_close:
await t.close()
node_close.assert_called_with()


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_request_with_custom_user_agent_header():
t = AsyncTransport(
[NodeConfig("http", "localhost", 80)],
Expand All @@ -91,7 +97,7 @@ async def test_request_with_custom_user_agent_header():
} == t.node_pool.get().calls[0][1]


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_body_gets_encoded_into_bytes():
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)

Expand All @@ -105,7 +111,7 @@ async def test_body_gets_encoded_into_bytes():
assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}'


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_body_bytes_get_passed_untouched():
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)

Expand All @@ -131,7 +137,7 @@ def test_kwargs_passed_on_to_node_pool():
assert dt is t.node_pool.max_dead_node_backoff


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_request_will_fail_after_x_retries():
t = AsyncTransport(
[
Expand All @@ -154,7 +160,7 @@ async def test_request_will_fail_after_x_retries():


@pytest.mark.parametrize("retry_on_timeout", [True, False])
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_retry_on_timeout(retry_on_timeout):
t = AsyncTransport(
[
Expand Down Expand Up @@ -189,7 +195,7 @@ async def test_retry_on_timeout(retry_on_timeout):
assert len(e.value.errors) == 0


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_retry_on_status():
t = AsyncTransport(
[
Expand Down Expand Up @@ -233,7 +239,7 @@ async def test_retry_on_status():
]


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_failed_connection_will_be_marked_as_dead():
t = AsyncTransport(
[
Expand Down Expand Up @@ -262,7 +268,7 @@ async def test_failed_connection_will_be_marked_as_dead():
assert all(isinstance(error, ConnectionError) for error in e.value.errors)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_resurrected_connection_will_be_marked_as_live_on_success():
for method in ("GET", "HEAD"):
t = AsyncTransport(
Expand All @@ -283,7 +289,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success():
assert 1 == len(t.node_pool._dead_nodes.queue)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_mark_dead_error_doesnt_raise():
t = AsyncTransport(
[
Expand All @@ -303,7 +309,7 @@ async def test_mark_dead_error_doesnt_raise():
mark_dead.assert_called_with(bad_node)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_node_class_as_string():
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp")
assert isinstance(t.node_pool.get(), AiohttpHttpNode)
Expand All @@ -320,7 +326,7 @@ async def test_node_class_as_string():


@pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)])
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_head_response_true(status, boolean):
t = AsyncTransport(
[NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})],
Expand All @@ -331,7 +337,7 @@ async def test_head_response_true(status, boolean):
assert data is None


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_head_response_false():
t = AsyncTransport(
[NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})],
Expand All @@ -353,7 +359,7 @@ async def test_head_response_false():
(HttpxAsyncHttpNode, "hx"),
],
)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_transport_client_meta_node_class(node_class, client_short_name):
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class)
assert (
Expand All @@ -366,7 +372,7 @@ async def test_transport_client_meta_node_class(node_class, client_short_name):
)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_transport_default_client_meta_node_class():
# Defaults to aiohttp
t = AsyncTransport(
Expand Down Expand Up @@ -635,7 +641,7 @@ async def test_sniff_on_start_no_results_errors(sniff_callback):


@pytest.mark.parametrize("pool_size", [1, 8])
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_multiple_tasks_test(pool_size):
node_configs = [
NodeConfig("http", "localhost", 80),
Expand All @@ -648,34 +654,45 @@ async def sniff_callback(*_):
await asyncio.sleep(random.random())
return node_configs

kwargs = {}
if sniffio.current_async_library() == "asyncio":
kwargs = {
"sniff_on_start": True,
"sniff_before_requests": True,
"sniff_on_node_failure": True,
"sniff_callback": sniff_callback,
}

print(kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this print intended or left over from a debug session?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, removed!


t = AsyncTransport(
node_configs,
retry_on_status=[500],
max_retries=5,
node_class=AsyncDummyNode,
sniff_on_start=True,
sniff_before_requests=True,
sniff_on_node_failure=True,
sniff_callback=sniff_callback,
**kwargs,
)

loop = asyncio.get_running_loop()
start = loop.time()
start = time.monotonic()

successful_requests = 0

async def run_requests():
successful_requests = 0
while loop.time() - start < 2:
nonlocal successful_requests
while time.monotonic() - start < 2:
await t.perform_request("GET", "/")
successful_requests += 1
return successful_requests

tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)]
assert sum([await task for task in tasks]) >= 1000
async with anyio.create_task_group() as tg:
for _ in range(pool_size * 2):
tg.start_soon(run_requests)
assert successful_requests >= 1000


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_httpbin(httpbin_node_config):
t = AsyncTransport([httpbin_node_config])
t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
resp = await t.perform_request("GET", "/anything")
assert resp.meta.status == 200
assert isinstance(resp.body, dict)
Loading