diff --git a/elastic_transport/_async_transport.py b/elastic_transport/_async_transport.py
index 3301714..b51120c 100644
--- a/elastic_transport/_async_transport.py
+++ b/elastic_transport/_async_transport.py
@@ -17,6 +17,7 @@
 
 import asyncio
 import logging
+import time
 from typing import (
     Any,
     Awaitable,
@@ -30,6 +31,8 @@
     Union,
 )
 
+import sniffio
+
 from ._compat import await_if_coro
 from ._exceptions import (
     ConnectionError,
@@ -169,6 +172,7 @@ def __init__(
         # time it's needed. Gets set within '_async_call()' which should
         # precede all logic within async calls.
         self._loop: asyncio.AbstractEventLoop = None  # type: ignore[assignment]
+        self._async_library: str = None  # type: ignore[assignment]
 
         # AsyncTransport doesn't require a thread lock for
         # sniffing. Uses '_sniffing_task' instead.
@@ -258,7 +262,7 @@ async def perform_request(  # type: ignore[override, return]
             node_failure = False
             last_response: Optional[TransportApiResponse] = None
             node: BaseAsyncNode = self.node_pool.get()  # type: ignore[assignment]
-            start_time = self._loop.time()
+            start_time = time.monotonic()
             try:
                 otel_span.set_node_metadata(
                     node.host, node.port, node.base_url, target, method
@@ -277,7 +281,7 @@ async def perform_request(  # type: ignore[override, return]
                         node.base_url,
                         target,
                         resp.meta.status,
-                        self._loop.time() - start_time,
+                        time.monotonic() - start_time,
                     )
                 )
 
@@ -300,7 +304,7 @@ async def perform_request(  # type: ignore[override, return]
                         node.base_url,
                         target,
                         "N/A",
-                        self._loop.time() - start_time,
+                        time.monotonic() - start_time,
                     )
                 )
 
@@ -377,6 +381,10 @@ async def perform_request(  # type: ignore[override, return]
         )
 
     async def sniff(self, is_initial_sniff: bool = False) -> None:  # type: ignore[override]
+        if sniffio.current_async_library() == "trio":
+            raise ValueError(
+                "Asynchronous sniffing is not supported with the 'trio' async library"
+            )
         await self._async_call()
 
         task = self._create_sniffing_task(is_initial_sniff)
@@ -409,8 +417,7 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
             self._sniffing_task.result()
 
         return (
-            self._loop.time() - self._last_sniffed_at
-            >= self._min_delay_between_sniffing
+            time.monotonic() - self._last_sniffed_at >= self._min_delay_between_sniffing
         )
 
     def _create_sniffing_task(
@@ -429,7 +436,7 @@ async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
         """Implementation of the sniffing task"""
         previously_sniffed_at = self._last_sniffed_at
         try:
-            self._last_sniffed_at = self._loop.time()
+            self._last_sniffed_at = time.monotonic()
             options = SniffOptions(
                 is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
             )
@@ -466,8 +473,13 @@ async def _async_call(self) -> None:
         because we're not guaranteed to be within an active asyncio event loop
         when __init__() is called.
         """
-        if self._loop is not None:
+        if self._async_library is not None:
             return  # Call at most once!
+
+        self._async_library = sniffio.current_async_library()
+        if self._async_library == "trio":
+            return
+
         self._loop = asyncio.get_running_loop()
         if self._sniff_on_start:
             await self.sniff(True)
diff --git a/elastic_transport/_node/_http_aiohttp.py b/elastic_transport/_node/_http_aiohttp.py
index 5ed1700..5fe69f7 100644
--- a/elastic_transport/_node/_http_aiohttp.py
+++ b/elastic_transport/_node/_http_aiohttp.py
@@ -72,7 +72,10 @@ class RequestKwarg(TypedDict, total=False):
 
 
 class AiohttpHttpNode(BaseAsyncNode):
-    """Default asynchronous node class using the ``aiohttp`` library via HTTP"""
+    """Default asynchronous node class using the ``aiohttp`` library via HTTP.
+
+    Supports asyncio.
+    """
 
     _CLIENT_META_HTTP_CLIENT = ("ai", _AIOHTTP_META_VERSION)
 
diff --git a/elastic_transport/_node/_http_httpx.py b/elastic_transport/_node/_http_httpx.py
index 04ceb60..2ac6de2 100644
--- a/elastic_transport/_node/_http_httpx.py
+++ b/elastic_transport/_node/_http_httpx.py
@@ -46,6 +46,10 @@
 
 
 class HttpxAsyncHttpNode(BaseAsyncNode):
+    """
+    Async HTTP node using httpx. Supports both trio and asyncio.
+    """
+
     _CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION)
 
     def __init__(self, config: NodeConfig):
@@ -175,11 +179,11 @@ async def perform_request(  # type: ignore[override]
                 body=body,
                 exception=err,
             )
-            raise err from None
+            raise err from e
 
         meta = ApiResponseMeta(
             resp.status_code,
-            resp.http_version,
+            resp.http_version.lstrip("HTTP/"),
             HttpHeaders(resp.headers),
             duration,
             self.config,
diff --git a/setup.py b/setup.py
index 2183271..208521f 100644
--- a/setup.py
+++ b/setup.py
@@ -52,6 +52,7 @@
     install_requires=[
        "urllib3>=1.26.2, <3",
        "certifi",
+       "sniffio",
    ],
    python_requires=">=3.8",
    extras_require={
@@ -70,6 +71,8 @@
            "opentelemetry-api",
            "opentelemetry-sdk",
            "orjson",
+           "anyio",
+           "trio",
            # Override Read the Docs default (sphinx<2)
            "sphinx>2",
            "furo",
diff --git a/tests/async_/test_async_transport.py b/tests/async_/test_async_transport.py
index 24a869c..4ed25ee 100644
--- a/tests/async_/test_async_transport.py
+++ b/tests/async_/test_async_transport.py
@@ -24,7 +24,9 @@
 import warnings
 from unittest import mock
 
+import anyio
 import pytest
+import sniffio
 
 from elastic_transport import (
     AiohttpHttpNode,
@@ -45,9 +47,11 @@ from tests.conftest import AsyncDummyNode
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_async_transport_httpbin(httpbin_node_config, httpbin):
-    t = AsyncTransport([httpbin_node_config], meta_header=False)
+    t = AsyncTransport(
+        [httpbin_node_config], meta_header=False, node_class=HttpxAsyncHttpNode
+    )
 
     resp, data = await t.perform_request("GET", "/anything?key=value")
     assert resp.status == 200
 
@@ -57,6 +61,8 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
     data["headers"].pop("X-Amzn-Trace-Id", None)
     assert data["headers"] == {
+        "Accept": "*/*",
+        "Accept-Encoding": "gzip, deflate, br",
         "User-Agent": DEFAULT_USER_AGENT,
         "Connection": "keep-alive",
         "Host": f"{httpbin.host}:{httpbin.port}",
     }
@@ -66,7 +72,7 @@
 @pytest.mark.skipif(
     sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8"
 )
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_transport_close_node_pool():
     t = AsyncTransport([NodeConfig("http", "localhost", 443)])
     with mock.patch.object(t.node_pool.all()[0], "close") as node_close:
@@ -74,7 +80,7 @@ async def test_transport_close_node_pool():
         node_close.assert_called_with()
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_request_with_custom_user_agent_header():
     t = AsyncTransport(
         [NodeConfig("http", "localhost", 80)],
@@ -91,7 +97,7 @@ async def test_request_with_custom_user_agent_header():
     } == t.node_pool.get().calls[0][1]
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_body_gets_encoded_into_bytes():
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
 
@@ -105,7 +111,7 @@ async def test_body_gets_encoded_into_bytes():
     assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}'
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_body_bytes_get_passed_untouched():
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
 
@@ -131,7 +137,7 @@ def test_kwargs_passed_on_to_node_pool():
     assert dt is t.node_pool.max_dead_node_backoff
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_request_will_fail_after_x_retries():
     t = AsyncTransport(
         [
@@ -154,7 +160,7 @@ async def test_request_will_fail_after_x_retries():
 
 
 @pytest.mark.parametrize("retry_on_timeout", [True, False])
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_retry_on_timeout(retry_on_timeout):
     t = AsyncTransport(
         [
@@ -189,7 +195,7 @@ async def test_retry_on_timeout(retry_on_timeout):
         assert len(e.value.errors) == 0
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_retry_on_status():
     t = AsyncTransport(
         [
@@ -233,7 +239,7 @@ async def test_retry_on_status():
     ]
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_failed_connection_will_be_marked_as_dead():
     t = AsyncTransport(
         [
@@ -262,7 +268,7 @@ async def test_failed_connection_will_be_marked_as_dead():
     assert all(isinstance(error, ConnectionError) for error in e.value.errors)
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_resurrected_connection_will_be_marked_as_live_on_success():
     for method in ("GET", "HEAD"):
         t = AsyncTransport(
@@ -283,7 +289,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success():
         assert 1 == len(t.node_pool._dead_nodes.queue)
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_mark_dead_error_doesnt_raise():
     t = AsyncTransport(
         [
@@ -303,7 +309,7 @@ async def test_mark_dead_error_doesnt_raise():
         mark_dead.assert_called_with(bad_node)
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_node_class_as_string():
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp")
     assert isinstance(t.node_pool.get(), AiohttpHttpNode)
@@ -320,7 +326,7 @@
 
 
 @pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)])
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_head_response_true(status, boolean):
     t = AsyncTransport(
         [NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})],
@@ -331,7 +337,7 @@ async def test_head_response_true(status, boolean):
     assert data is None
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_head_response_false():
     t = AsyncTransport(
         [NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})],
@@ -353,7 +359,7 @@
         (HttpxAsyncHttpNode, "hx"),
     ],
 )
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_transport_client_meta_node_class(node_class, client_short_name):
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class)
     assert (
@@ -366,7 +372,7 @@ async def test_transport_client_meta_node_class(node_class, client_short_name):
     )
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_transport_default_client_meta_node_class():
     # Defaults to aiohttp
     t = AsyncTransport(
@@ -635,7 +641,7 @@ async def test_sniff_on_start_no_results_errors(sniff_callback):
 
 
 @pytest.mark.parametrize("pool_size", [1, 8])
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_multiple_tasks_test(pool_size):
     node_configs = [
         NodeConfig("http", "localhost", 80),
@@ -648,34 +654,43 @@ async def sniff_callback(*_):
         await asyncio.sleep(random.random())
         return node_configs
 
+    kwargs = {}
+    if sniffio.current_async_library() == "asyncio":
+        kwargs = {
+            "sniff_on_start": True,
+            "sniff_before_requests": True,
+            "sniff_on_node_failure": True,
+            "sniff_callback": sniff_callback,
+        }
+
     t = AsyncTransport(
         node_configs,
         retry_on_status=[500],
         max_retries=5,
         node_class=AsyncDummyNode,
-        sniff_on_start=True,
-        sniff_before_requests=True,
-        sniff_on_node_failure=True,
-        sniff_callback=sniff_callback,
+        **kwargs,
     )
 
-    loop = asyncio.get_running_loop()
-    start = loop.time()
+    start = time.monotonic()
+
+    successful_requests = 0
 
     async def run_requests():
-        successful_requests = 0
-        while loop.time() - start < 2:
+        nonlocal successful_requests
+        while time.monotonic() - start < 2:
             await t.perform_request("GET", "/")
             successful_requests += 1
         return successful_requests
 
-    tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)]
-    assert sum([await task for task in tasks]) >= 1000
+    async with anyio.create_task_group() as tg:
+        for _ in range(pool_size * 2):
+            tg.start_soon(run_requests)
+    assert successful_requests >= 1000
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_httpbin(httpbin_node_config):
-    t = AsyncTransport([httpbin_node_config])
+    t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
     resp = await t.perform_request("GET", "/anything")
     assert resp.meta.status == 200
     assert isinstance(resp.body, dict)
diff --git a/tests/async_/test_httpbin.py b/tests/async_/test_httpbin.py
index f6cc747..9113e80 100644
--- a/tests/async_/test_httpbin.py
+++ b/tests/async_/test_httpbin.py
@@ -20,15 +20,15 @@
 
 import pytest
 
-from elastic_transport import AiohttpHttpNode, AsyncTransport
+from elastic_transport import AsyncTransport, HttpxAsyncHttpNode
 from elastic_transport._node._base import DEFAULT_USER_AGENT
 
 from ..test_httpbin import parse_httpbin
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_simple_request(httpbin_node_config, httpbin):
-    t = AsyncTransport([httpbin_node_config])
+    t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
 
     resp, data = await t.perform_request(
         "GET",
@@ -59,10 +59,10 @@ async def test_simple_request(httpbin_node_config, httpbin):
     assert all(v == data["headers"][k] for k, v in request_headers.items())
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_node(httpbin_node_config, httpbin):
     def new_node(**kwargs):
-        return AiohttpHttpNode(dataclasses.replace(httpbin_node_config, **kwargs))
+        return HttpxAsyncHttpNode(dataclasses.replace(httpbin_node_config, **kwargs))
 
     node = new_node()
     resp, data = await node.perform_request("GET", "/anything")
@@ -70,6 +72,8 @@ def new_node(**kwargs):
     parsed = parse_httpbin(data)
     assert parsed == {
         "headers": {
+            "Accept": "*/*",
+            "Accept-Encoding": "gzip, deflate, br",
             "Connection": "keep-alive",
             "Host": f"{httpbin.host}:{httpbin.port}",
             "User-Agent": DEFAULT_USER_AGENT,
@@ -84,6 +86,7 @@ def new_node(**kwargs):
     parsed = parse_httpbin(data)
     assert parsed == {
         "headers": {
+            "Accept": "*/*",
             "Accept-Encoding": "gzip",
             "Connection": "keep-alive",
             "Host": f"{httpbin.host}:{httpbin.port}",
@@ -98,9 +101,9 @@ def new_node(**kwargs):
     parsed = parse_httpbin(data)
     assert parsed == {
         "headers": {
+            "Accept": "*/*",
             "Accept-Encoding": "gzip",
             "Content-Encoding": "gzip",
-            "Content-Type": "application/octet-stream",
             "Content-Length": "33",
             "Connection": "keep-alive",
             "Host": f"{httpbin.host}:{httpbin.port}",
@@ -120,6 +123,7 @@ def new_node(**kwargs):
     parsed = parse_httpbin(data)
     assert parsed == {
         "headers": {
+            "Accept": "*/*",
             "Accept-Encoding": "gzip",
             "Content-Encoding": "gzip",
             "Content-Length": "36",
diff --git a/tests/node/test_http_httpx.py b/tests/node/test_http_httpx.py
index ce6e7f4..da11404 100644
--- a/tests/node/test_http_httpx.py
+++ b/tests/node/test_http_httpx.py
@@ -82,7 +82,7 @@ def test_ca_certs_with_verify_ssl_false_raises_error(self):
         )
 
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 class TestHttpxAsyncNode:
     @respx.mock
     async def test_simple_request(self):
diff --git a/tests/node/test_tls_versions.py b/tests/node/test_tls_versions.py
index e687d9f..71f7c35 100644
--- a/tests/node/test_tls_versions.py
+++ b/tests/node/test_tls_versions.py
@@ -23,6 +23,7 @@
 
 from elastic_transport import (
     AiohttpHttpNode,
+    ConnectionError,
     HttpxAsyncHttpNode,
     NodeConfig,
     RequestsHttpNode,
@@ -98,10 +99,15 @@ def tlsv1_1_supported() -> bool:
     ["url", "ssl_version"],
     supported_version_params,
 )
-@pytest.mark.asyncio
-async def test_supported_tls_versions(node_class, url: str, ssl_version: int):
+@pytest.mark.anyio
+async def test_supported_tls_versions(
+    node_class, url: str, ssl_version: int, anyio_backend
+):
     if url in (TLSv1_0_URL, TLSv1_1_URL) and not tlsv1_1_supported():
         pytest.skip("TLSv1.1 isn't supported by this OpenSSL distribution")
 
+    if anyio_backend == "trio" and node_class is not HttpxAsyncHttpNode:
+        pytest.skip("only httpx supports trio")
+
     node_config = url_to_node_config(url).replace(ssl_version=ssl_version)
     node = node_class(node_config)
@@ -114,20 +120,31 @@
     ["url", "ssl_version"],
     unsupported_version_params,
 )
-@pytest.mark.asyncio
-async def test_unsupported_tls_versions(node_class, url: str, ssl_version: int):
+@pytest.mark.anyio
+async def test_unsupported_tls_versions(
+    node_class, url: str, ssl_version: int, anyio_backend
+):
+    if anyio_backend == "trio" and node_class is not HttpxAsyncHttpNode:
+        pytest.skip("only httpx supports trio")
+
     node_config = url_to_node_config(url).replace(ssl_version=ssl_version)
     node = node_class(node_config)
 
-    with pytest.raises(TlsError) as e:
+    # Remove ConnectionError when we have a fix or workaround for
+    # https://github.com/encode/httpx/discussions/3674
+    with pytest.raises((TlsError, ConnectionError)) as e:
         await await_if_coro(node.perform_request("GET", "/"))
+    if anyio_backend == "trio" and node_class is HttpxAsyncHttpNode:
+        return  # Trio errors are not correctly bubbled up by httpx
     assert "unsupported protocol" in str(e.value) or "handshake failure" in str(e.value)
 
 
 @node_classes
 @pytest.mark.parametrize("ssl_version", [0, "TLSv1", object()])
 def test_ssl_version_value_error(node_class, ssl_version):
-    with pytest.raises(ValueError) as e:
+    # Remove ConnectionError when we have a fix or workaround for
+    # https://github.com/encode/httpx/discussions/3674
+    with pytest.raises((ValueError, ConnectionError)) as e:
         node_class(NodeConfig("https", "localhost", 9200, ssl_version=ssl_version))
 
     assert str(e.value) == (
         f"Unsupported value for 'ssl_version': {ssl_version!r}. Must be either "
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 98e084c..24fd5d7 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -24,6 +24,7 @@
     AiohttpHttpNode,
     ConnectionError,
     HttpHeaders,
+    HttpxAsyncHttpNode,
     RequestsHttpNode,
     Urllib3HttpNode,
     debug_logging,
@@ -32,13 +33,17 @@
 from elastic_transport._node._base import DEFAULT_USER_AGENT
 
 
 node_class = pytest.mark.parametrize(
-    "node_class", [Urllib3HttpNode, RequestsHttpNode, AiohttpHttpNode]
+    "node_class",
+    [Urllib3HttpNode, RequestsHttpNode, AiohttpHttpNode, HttpxAsyncHttpNode],
 )
 
 
 @node_class
-@pytest.mark.asyncio
-async def test_debug_logging(node_class, httpbin_node_config, httpbin):
+@pytest.mark.anyio
+async def test_debug_logging(node_class, anyio_backend, httpbin_node_config, httpbin):
+    if anyio_backend == "trio" and node_class is not HttpxAsyncHttpNode:
+        pytest.skip("only httpx supports trio")
+
     debug_logging()
     stream = io.StringIO()
@@ -92,8 +97,13 @@ async def test_debug_logging(node_class, httpbin_node_config, httpbin):
 
 
 @node_class
-@pytest.mark.asyncio
-async def test_debug_logging_uncompressed_body(httpbin_node_config, node_class):
+@pytest.mark.anyio
+async def test_debug_logging_uncompressed_body(
+    httpbin_node_config, node_class, anyio_backend
+):
+    if anyio_backend == "trio" and node_class is not HttpxAsyncHttpNode:
+        pytest.skip("only httpx supports trio")
+
     debug_logging()
     stream = io.StringIO()
     logging.getLogger("elastic_transport.node").addHandler(
@@ -116,8 +126,11 @@ async def test_debug_logging_uncompressed_body(httpbin_node_config, node_class):
 
 
 @node_class
-@pytest.mark.asyncio
-async def test_debug_logging_no_body(httpbin_node_config, node_class):
+@pytest.mark.anyio
+async def test_debug_logging_no_body(httpbin_node_config, node_class, anyio_backend):
+    if anyio_backend == "trio" and node_class is not HttpxAsyncHttpNode:
+        pytest.skip("only httpx supports trio")
+
     debug_logging()
     stream = io.StringIO()
     logging.getLogger("elastic_transport.node").addHandler(
@@ -137,8 +150,11 @@ async def test_debug_logging_no_body(httpbin_node_config, node_class):
 
 
 @node_class
-@pytest.mark.asyncio
-async def test_debug_logging_error(httpbin_node_config, node_class):
+@pytest.mark.anyio
+async def test_debug_logging_error(httpbin_node_config, node_class, anyio_backend):
+    if anyio_backend == "trio" and node_class is not HttpxAsyncHttpNode:
+        pytest.skip("only httpx supports trio")
+
     debug_logging()
     stream = io.StringIO()
     logging.getLogger("elastic_transport.node").addHandler(
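Not part of the patch: a minimal usage sketch of what this change enables, i.e. running AsyncTransport under trio by selecting the httpx-based node class. The localhost:9200 endpoint is a hypothetical placeholder; the calls shown (AsyncTransport, HttpxAsyncHttpNode, NodeConfig, perform_request, close) mirror the public API exercised by the tests above.

    # Usage sketch, assuming a reachable HTTP service on localhost:9200 (hypothetical).
    import trio

    from elastic_transport import AsyncTransport, HttpxAsyncHttpNode, NodeConfig


    async def main() -> None:
        # Sniffing is unavailable under trio (see the guard added to sniff()),
        # so no sniff_* options are passed here.
        transport = AsyncTransport(
            [NodeConfig("http", "localhost", 9200)],
            node_class=HttpxAsyncHttpNode,
        )
        resp = await transport.perform_request("GET", "/")
        print(resp.meta.status, type(resp.body))
        await transport.close()


    trio.run(main)

Under asyncio the same code works unchanged with either node class; only the sniff_* options require asyncio, which is what the sniffio check added in sniff() enforces.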