26 changes: 19 additions & 7 deletions elastic_transport/_async_transport.py
@@ -17,6 +17,7 @@

import asyncio
import logging
+import time
from typing import (
Any,
Awaitable,
@@ -30,6 +31,8 @@
Union,
)

+import sniffio
+
from ._compat import await_if_coro
from ._exceptions import (
ConnectionError,
@@ -169,6 +172,7 @@ def __init__(
# time it's needed. Gets set within '_async_call()' which should
# precede all logic within async calls.
self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]
+self._async_library: str = None # type: ignore[assignment]

# AsyncTransport doesn't require a thread lock for
# sniffing. Uses '_sniffing_task' instead.
@@ -258,7 +262,7 @@ async def perform_request( # type: ignore[override, return]
node_failure = False
last_response: Optional[TransportApiResponse] = None
node: BaseAsyncNode = self.node_pool.get() # type: ignore[assignment]
-start_time = self._loop.time()
+start_time = time.monotonic()
try:
otel_span.set_node_metadata(
node.host, node.port, node.base_url, target, method
@@ -277,7 +281,7 @@ async def perform_request( # type: ignore[override, return]
node.base_url,
target,
resp.meta.status,
-self._loop.time() - start_time,
+time.monotonic() - start_time,
)
)

@@ -300,7 +304,7 @@ async def perform_request( # type: ignore[override, return]
node.base_url,
target,
"N/A",
-self._loop.time() - start_time,
+time.monotonic() - start_time,
)
)

@@ -377,6 +381,10 @@ async def perform_request( # type: ignore[override, return]
)

async def sniff(self, is_initial_sniff: bool = False) -> None: # type: ignore[override]
+if sniffio.current_async_library() == "trio":
+    raise ValueError(
+        f"Asynchronous sniffing is not supported with the 'trio' library, got {sniffio.current_async_library()}"
+    )
await self._async_call()
task = self._create_sniffing_task(is_initial_sniff)

@@ -409,8 +417,7 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
self._sniffing_task.result()

return (
-self._loop.time() - self._last_sniffed_at
->= self._min_delay_between_sniffing
+time.monotonic() - self._last_sniffed_at >= self._min_delay_between_sniffing
)

def _create_sniffing_task(
@@ -429,7 +436,7 @@ async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
"""Implementation of the sniffing task"""
previously_sniffed_at = self._last_sniffed_at
try:
-self._last_sniffed_at = self._loop.time()
+self._last_sniffed_at = time.monotonic()
options = SniffOptions(
is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
)
@@ -466,8 +473,13 @@ async def _async_call(self) -> None:
because we're not guaranteed to be within an active asyncio event loop
when __init__() is called.
"""
-if self._loop is not None:
+if self._async_library is not None:
return # Call at most once!

+self._async_library = sniffio.current_async_library()
+if self._async_library == "trio":
+    return
+
self._loop = asyncio.get_running_loop()
if self._sniff_on_start:
await self.sniff(True)
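For context on the transport changes above: `_async_call()` now uses `sniffio` to detect the running async framework once per transport, and timing moves from `self._loop.time()` to `time.monotonic()` because there is no asyncio event loop to query when running under trio; asynchronous sniffing itself remains asyncio-only. A minimal sketch of the detection call these guards rely on (the `which_library` helper below is illustrative, not part of the library):

```python
# Illustrative sketch, not part of this diff: sniffio reports which async
# framework is currently driving the coroutine, which is what the new
# _async_call()/sniff() guards check.
import asyncio

import sniffio
import trio  # assumed installed; added to the dev extras in setup.py below


async def which_library() -> str:
    # Returns "asyncio" or "trio" depending on the running event loop.
    return sniffio.current_async_library()


print(asyncio.run(which_library()))  # -> "asyncio"
print(trio.run(which_library))       # -> "trio"
```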
5 changes: 4 additions & 1 deletion elastic_transport/_node/_http_aiohttp.py
@@ -72,7 +72,10 @@ class RequestKwarg(TypedDict, total=False):


class AiohttpHttpNode(BaseAsyncNode):
-"""Default asynchronous node class using the ``aiohttp`` library via HTTP"""
+"""Default asynchronous node class using the ``aiohttp`` library via HTTP.
+
+Supports asyncio.
+"""

_CLIENT_META_HTTP_CLIENT = ("ai", _AIOHTTP_META_VERSION)

8 changes: 6 additions & 2 deletions elastic_transport/_node/_http_httpx.py
@@ -46,6 +46,10 @@


class HttpxAsyncHttpNode(BaseAsyncNode):
+"""
+Async HTTP node using httpx. Supports both Trio and asyncio.
+"""
+
_CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION)

def __init__(self, config: NodeConfig):
@@ -175,11 +179,11 @@ async def perform_request( # type: ignore[override]
body=body,
exception=err,
)
-raise err from None
+raise err from e

meta = ApiResponseMeta(
resp.status_code,
-resp.http_version,
+resp.http_version.lstrip("HTTP/"),
HttpHeaders(resp.headers),
duration,
self.config,
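Per the docstrings above, `AiohttpHttpNode` stays asyncio-only while `HttpxAsyncHttpNode` works under both backends. A hedged usage sketch, not taken from this PR (the node address is a placeholder and sniffing is left at its default, disabled):

```python
# Hypothetical example: driving AsyncTransport with HttpxAsyncHttpNode from trio.
import trio

from elastic_transport import AsyncTransport, HttpxAsyncHttpNode, NodeConfig


async def main() -> None:
    transport = AsyncTransport(
        [NodeConfig("http", "localhost", 8080)],  # placeholder host/port
        node_class=HttpxAsyncHttpNode,
    )
    resp = await transport.perform_request("GET", "/anything")
    print(resp.meta.status)
    await transport.close()


trio.run(main)
```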
3 changes: 3 additions & 0 deletions setup.py
@@ -52,6 +52,7 @@
install_requires=[
"urllib3>=1.26.2, <3",
"certifi",
+"sniffio",
],
python_requires=">=3.8",
extras_require={
@@ -70,6 +71,8 @@
"opentelemetry-api",
"opentelemetry-sdk",
"orjson",
+"anyio",
+"trio",
# Override Read the Docs default (sphinx<2)
"sphinx>2",
"furo",
75 changes: 45 additions & 30 deletions tests/async_/test_async_transport.py
@@ -24,7 +24,9 @@
import warnings
from unittest import mock

+import anyio
import pytest
+import sniffio

from elastic_transport import (
AiohttpHttpNode,
@@ -45,9 +47,11 @@
from tests.conftest import AsyncDummyNode


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_async_transport_httpbin(httpbin_node_config, httpbin):
-t = AsyncTransport([httpbin_node_config], meta_header=False)
+t = AsyncTransport(
+    [httpbin_node_config], meta_header=False, node_class=HttpxAsyncHttpNode
+)
resp, data = await t.perform_request("GET", "/anything?key=value")

assert resp.status == 200
@@ -57,6 +61,8 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):

data["headers"].pop("X-Amzn-Trace-Id", None)
assert data["headers"] == {
+"Accept": "*/*",
+"Accept-Encoding": "gzip, deflate, br",
"User-Agent": DEFAULT_USER_AGENT,
"Connection": "keep-alive",
"Host": f"{httpbin.host}:{httpbin.port}",
@@ -66,15 +72,15 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8"
)
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_transport_close_node_pool():
t = AsyncTransport([NodeConfig("http", "localhost", 443)])
with mock.patch.object(t.node_pool.all()[0], "close") as node_close:
await t.close()
node_close.assert_called_with()


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_request_with_custom_user_agent_header():
t = AsyncTransport(
[NodeConfig("http", "localhost", 80)],
@@ -91,7 +97,7 @@ async def test_request_with_custom_user_agent_header():
} == t.node_pool.get().calls[0][1]


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_body_gets_encoded_into_bytes():
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)

@@ -105,7 +111,7 @@ async def test_body_gets_encoded_into_bytes():
assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}'


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_body_bytes_get_passed_untouched():
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)

@@ -131,7 +137,7 @@ def test_kwargs_passed_on_to_node_pool():
assert dt is t.node_pool.max_dead_node_backoff


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_request_will_fail_after_x_retries():
t = AsyncTransport(
[
@@ -154,7 +160,7 @@ async def test_request_will_fail_after_x_retries():


@pytest.mark.parametrize("retry_on_timeout", [True, False])
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_retry_on_timeout(retry_on_timeout):
t = AsyncTransport(
[
@@ -189,7 +195,7 @@ async def test_retry_on_timeout(retry_on_timeout):
assert len(e.value.errors) == 0


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_retry_on_status():
t = AsyncTransport(
[
@@ -233,7 +239,7 @@ async def test_retry_on_status():
]


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_failed_connection_will_be_marked_as_dead():
t = AsyncTransport(
[
@@ -262,7 +268,7 @@ async def test_failed_connection_will_be_marked_as_dead():
assert all(isinstance(error, ConnectionError) for error in e.value.errors)


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_resurrected_connection_will_be_marked_as_live_on_success():
for method in ("GET", "HEAD"):
t = AsyncTransport(
@@ -283,7 +289,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success():
assert 1 == len(t.node_pool._dead_nodes.queue)


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_mark_dead_error_doesnt_raise():
t = AsyncTransport(
[
@@ -303,7 +309,7 @@ async def test_mark_dead_error_doesnt_raise():
mark_dead.assert_called_with(bad_node)


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_node_class_as_string():
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp")
assert isinstance(t.node_pool.get(), AiohttpHttpNode)
@@ -320,7 +326,7 @@ async def test_node_class_as_string():


@pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)])
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_head_response_true(status, boolean):
t = AsyncTransport(
[NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})],
@@ -331,7 +337,7 @@ async def test_head_response_true(status, boolean):
assert data is None


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_head_response_false():
t = AsyncTransport(
[NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})],
@@ -353,7 +359,7 @@ async def test_head_response_false():
(HttpxAsyncHttpNode, "hx"),
],
)
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_transport_client_meta_node_class(node_class, client_short_name):
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class)
assert (
@@ -366,7 +372,7 @@ async def test_transport_client_meta_node_class(node_class, client_short_name):
)


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_transport_default_client_meta_node_class():
# Defaults to aiohttp
t = AsyncTransport(
@@ -635,7 +641,7 @@ async def test_sniff_on_start_no_results_errors(sniff_callback):


@pytest.mark.parametrize("pool_size", [1, 8])
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_multiple_tasks_test(pool_size):
node_configs = [
NodeConfig("http", "localhost", 80),
@@ -648,34 +654,43 @@ async def sniff_callback(*_):
await asyncio.sleep(random.random())
return node_configs

+kwargs = {}
+if sniffio.current_async_library() == "asyncio":
+    kwargs = {
+        "sniff_on_start": True,
+        "sniff_before_requests": True,
+        "sniff_on_node_failure": True,
+        "sniff_callback": sniff_callback,
+    }

t = AsyncTransport(
node_configs,
retry_on_status=[500],
max_retries=5,
node_class=AsyncDummyNode,
-sniff_on_start=True,
-sniff_before_requests=True,
-sniff_on_node_failure=True,
-sniff_callback=sniff_callback,
+**kwargs,
)

-loop = asyncio.get_running_loop()
-start = loop.time()
+start = time.monotonic()

+successful_requests = 0

async def run_requests():
-successful_requests = 0
-while loop.time() - start < 2:
+nonlocal successful_requests
+while time.monotonic() - start < 2:
await t.perform_request("GET", "/")
successful_requests += 1
-return successful_requests

-tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)]
-assert sum([await task for task in tasks]) >= 1000
+async with anyio.create_task_group() as tg:
+    for _ in range(pool_size * 2):
+        tg.start_soon(run_requests)
+assert successful_requests >= 1000


-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_httpbin(httpbin_node_config):
-t = AsyncTransport([httpbin_node_config])
+t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
resp = await t.perform_request("GET", "/anything")
assert resp.meta.status == 200
assert isinstance(resp.body, dict)
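The move from `@pytest.mark.asyncio` to `@pytest.mark.anyio` lets each test run under both backends via the anyio pytest plugin. That plugin selects backends through an `anyio_backend` fixture; the sketch below shows the standard way to parametrize it and is an assumption, since the conftest change is not part of this diff:

```python
# Assumed tests/conftest.py fixture (not shown in this diff): the anyio pytest
# plugin reads "anyio_backend" to decide which backend(s) each
# @pytest.mark.anyio test runs on.
import pytest


@pytest.fixture(params=["asyncio", "trio"])
def anyio_backend(request):
    return request.param
```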