Support Trio with httpx #263
Changes to the async transport module:

@@ -17,6 +17,7 @@
 import asyncio
 import logging
+import time
 from typing import (
     Any,
     Awaitable,
@@ -30,6 +31,8 @@
     Union,
 )

+import sniffio
+
 from ._compat import await_if_coro
 from ._exceptions import (
     ConnectionError,
@@ -169,6 +172,7 @@ def __init__(
         # time it's needed. Gets set within '_async_call()' which should
         # precede all logic within async calls.
         self._loop: asyncio.AbstractEventLoop = None  # type: ignore[assignment]
+        self._async_library: str = None  # type: ignore[assignment]

         # AsyncTransport doesn't require a thread lock for
         # sniffing. Uses '_sniffing_task' instead.
@@ -258,7 +262,7 @@ async def perform_request(  # type: ignore[override, return]
         node_failure = False
         last_response: Optional[TransportApiResponse] = None
         node: BaseAsyncNode = self.node_pool.get()  # type: ignore[assignment]
-        start_time = self._loop.time()
+        start_time = time.monotonic()
         try:
             otel_span.set_node_metadata(
                 node.host, node.port, node.base_url, target, method
@@ -277,7 +281,7 @@ async def perform_request(  # type: ignore[override, return]
                     node.base_url,
                     target,
                     resp.meta.status,
-                    self._loop.time() - start_time,
+                    time.monotonic() - start_time,
                 )
             )
@@ -300,7 +304,7 @@ async def perform_request(  # type: ignore[override, return]
                     node.base_url,
                     target,
                     "N/A",
-                    self._loop.time() - start_time,
+                    time.monotonic() - start_time,
                 )
             )
@@ -377,6 +381,10 @@ async def perform_request(  # type: ignore[override, return]
         )

     async def sniff(self, is_initial_sniff: bool = False) -> None:  # type: ignore[override]
+        if sniffio.current_async_library() != "asyncio":
+            raise ValueError(
+                f"Asynchronous sniffing only works with the 'asyncio' library, got {sniffio.current_async_library()}"
+            )
         await self._async_call()
         task = self._create_sniffing_task(is_initial_sniff)
@@ -409,8 +417,7 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
             self._sniffing_task.result()

         return (
-            self._loop.time() - self._last_sniffed_at
-            >= self._min_delay_between_sniffing
+            time.monotonic() - self._last_sniffed_at >= self._min_delay_between_sniffing
         )

     def _create_sniffing_task(
@@ -429,7 +436,7 @@ async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
         """Implementation of the sniffing task"""
         previously_sniffed_at = self._last_sniffed_at
         try:
-            self._last_sniffed_at = self._loop.time()
+            self._last_sniffed_at = time.monotonic()
             options = SniffOptions(
                 is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
             )
@@ -466,8 +473,13 @@ async def _async_call(self) -> None:
         because we're not guaranteed to be within an active asyncio event loop
         when __init__() is called.
         """
-        if self._loop is not None:
+        if self._async_library is not None:
             return  # Call at most once!

+        self._async_library = sniffio.current_async_library()
+        if self._async_library != "asyncio":
+            return
+
         self._loop = asyncio.get_running_loop()
         if self._sniff_on_start:
             await self.sniff(True)
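For orientation (not part of the diff), here is a minimal sketch of what these changes are meant to enable: because the transport now uses the loop-independent `time.monotonic()` for timings and a sniffio-based guard instead of assuming asyncio, it can run under Trio as long as sniffing stays disabled. The `localhost:9200` endpoint is a hypothetical server, and httpx plus Trio are assumed installed:

```python
import trio

from elastic_transport import AsyncTransport, HttpxAsyncHttpNode, NodeConfig


async def main() -> None:
    # Sniffing options must stay off under Trio; sniff() raises ValueError
    # for any async library other than asyncio.
    transport = AsyncTransport(
        [NodeConfig("http", "localhost", 9200)],  # hypothetical local server
        node_class=HttpxAsyncHttpNode,
    )
    resp, data = await transport.perform_request("GET", "/")
    print(resp.status)


trio.run(main)
```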
Changes to the async transport tests:

@@ -24,7 +24,9 @@
 import warnings
 from unittest import mock

+import anyio
 import pytest
+import sniffio

 from elastic_transport import (
     AiohttpHttpNode,
@@ -45,9 +47,11 @@
 from tests.conftest import AsyncDummyNode


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_async_transport_httpbin(httpbin_node_config, httpbin):
-    t = AsyncTransport([httpbin_node_config], meta_header=False)
+    t = AsyncTransport(
+        [httpbin_node_config], meta_header=False, node_class=HttpxAsyncHttpNode
+    )
     resp, data = await t.perform_request("GET", "/anything?key=value")

     assert resp.status == 200
@@ -57,6 +61,8 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
     data["headers"].pop("X-Amzn-Trace-Id", None)
     assert data["headers"] == {
         "Accept": "*/*",
+        "Accept-Encoding": "gzip, deflate, br",
         "User-Agent": DEFAULT_USER_AGENT,
+        "Connection": "keep-alive",
         "Host": f"{httpbin.host}:{httpbin.port}",
@@ -66,15 +72,15 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
 @pytest.mark.skipif(
     sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8"
 )
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_transport_close_node_pool():
     t = AsyncTransport([NodeConfig("http", "localhost", 443)])
     with mock.patch.object(t.node_pool.all()[0], "close") as node_close:
         await t.close()
     node_close.assert_called_with()


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_request_with_custom_user_agent_header():
     t = AsyncTransport(
         [NodeConfig("http", "localhost", 80)],
@@ -91,7 +97,7 @@ async def test_request_with_custom_user_agent_header():
     } == t.node_pool.get().calls[0][1]


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_body_gets_encoded_into_bytes():
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
@@ -105,7 +111,7 @@ async def test_body_gets_encoded_into_bytes():
     assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}'


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_body_bytes_get_passed_untouched():
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
@@ -131,7 +137,7 @@ def test_kwargs_passed_on_to_node_pool():
     assert dt is t.node_pool.max_dead_node_backoff


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_request_will_fail_after_x_retries():
     t = AsyncTransport(
         [
@@ -154,7 +160,7 @@ async def test_request_will_fail_after_x_retries():


 @pytest.mark.parametrize("retry_on_timeout", [True, False])
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_retry_on_timeout(retry_on_timeout):
     t = AsyncTransport(
         [
@@ -189,7 +195,7 @@ async def test_retry_on_timeout(retry_on_timeout):
     assert len(e.value.errors) == 0


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_retry_on_status():
     t = AsyncTransport(
         [
@@ -233,7 +239,7 @@ async def test_retry_on_status():
     ]


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_failed_connection_will_be_marked_as_dead():
     t = AsyncTransport(
         [
@@ -262,7 +268,7 @@ async def test_failed_connection_will_be_marked_as_dead():
     assert all(isinstance(error, ConnectionError) for error in e.value.errors)


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_resurrected_connection_will_be_marked_as_live_on_success():
     for method in ("GET", "HEAD"):
         t = AsyncTransport(
@@ -283,7 +289,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success():
     assert 1 == len(t.node_pool._dead_nodes.queue)


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_mark_dead_error_doesnt_raise():
     t = AsyncTransport(
         [
@@ -303,7 +309,7 @@ async def test_mark_dead_error_doesnt_raise():
     mark_dead.assert_called_with(bad_node)


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_node_class_as_string():
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp")
     assert isinstance(t.node_pool.get(), AiohttpHttpNode)
@@ -320,7 +326,7 @@ async def test_node_class_as_string():


 @pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)])
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_head_response_true(status, boolean):
     t = AsyncTransport(
         [NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})],
@@ -331,7 +337,7 @@ async def test_head_response_true(status, boolean):
     assert data is None


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_head_response_false():
     t = AsyncTransport(
         [NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})],
@@ -353,7 +359,7 @@ async def test_head_response_false():
         (HttpxAsyncHttpNode, "hx"),
     ],
 )
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_transport_client_meta_node_class(node_class, client_short_name):
     t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class)
     assert (
@@ -366,7 +372,7 @@ async def test_transport_client_meta_node_class(node_class, client_short_name):
     )


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_transport_default_client_meta_node_class():
     # Defaults to aiohttp
     t = AsyncTransport(
@@ -635,7 +641,7 @@ async def test_sniff_on_start_no_results_errors(sniff_callback):


 @pytest.mark.parametrize("pool_size", [1, 8])
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_multiple_tasks_test(pool_size):
     node_configs = [
         NodeConfig("http", "localhost", 80),
@@ -648,34 +654,45 @@ async def sniff_callback(*_):
         await asyncio.sleep(random.random())
         return node_configs

+    kwargs = {}
+    if sniffio.current_async_library() == "asyncio":
+        kwargs = {
+            "sniff_on_start": True,
+            "sniff_before_requests": True,
+            "sniff_on_node_failure": True,
+            "sniff_callback": sniff_callback,
+        }
+
+    print(kwargs)
+
     t = AsyncTransport(
         node_configs,
         retry_on_status=[500],
         max_retries=5,
         node_class=AsyncDummyNode,
-        sniff_on_start=True,
-        sniff_before_requests=True,
-        sniff_on_node_failure=True,
-        sniff_callback=sniff_callback,
+        **kwargs,
     )

-    loop = asyncio.get_running_loop()
-    start = loop.time()
+    start = time.monotonic()

+    successful_requests = 0
+
     async def run_requests():
-        successful_requests = 0
-        while loop.time() - start < 2:
+        nonlocal successful_requests
+        while time.monotonic() - start < 2:
             await t.perform_request("GET", "/")
             successful_requests += 1
-        return successful_requests

-    tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)]
-    assert sum([await task for task in tasks]) >= 1000
+    async with anyio.create_task_group() as tg:
+        for _ in range(pool_size * 2):
+            tg.start_soon(run_requests)
+    assert successful_requests >= 1000


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_httpbin(httpbin_node_config):
-    t = AsyncTransport([httpbin_node_config])
+    t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
     resp = await t.perform_request("GET", "/anything")
     assert resp.meta.status == 200
     assert isinstance(resp.body, dict)
Wouldn't it be more appropriate here to show the error when `current_async_library() == "trio"`, since we know for a fact that sniffing does not work with it? We don't know whether sniffing works with other async libraries, I assume. Also, this function raises an exception when it cannot recognize which library is in use. For backwards compatibility, maybe we should allow sniffing in that case?
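(For context, a minimal sketch of the sniffio behavior referenced above; `sniffing_supported` is a hypothetical helper, not code from this PR:)

```python
import sniffio


def sniffing_supported() -> bool:
    # sniffio reports which async library the current task runs under and
    # raises AsyncLibraryNotFoundError when it cannot recognize one.
    try:
        library = sniffio.current_async_library()
    except sniffio.AsyncLibraryNotFoundError:
        # Unrecognized context: stay permissive for backwards compatibility.
        return True
    return library == "asyncio"
```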
What other async library?
`current_async_library() == "trio"` won't make any difference. I wouldn't see a different library making the same mistakes as asyncio, i.e. allowing background tasks. Anyway, Trio has shown how hard it is for a different async library to take off, even with superior design, excellent docs, etc. So, for me, the condition makes sense: today, we only support sniffing with asyncio.
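(To make the two options concrete, a hedged sketch of the guards under discussion; both helpers are illustrative, not code from this PR:)

```python
import sniffio


def guard_reject_trio() -> None:
    # Reviewer's suggestion: reject only the library known not to work.
    if sniffio.current_async_library() == "trio":
        raise ValueError("Sniffing is not supported under Trio")


def guard_allow_asyncio_only() -> None:
    # The diff's approach: allow only the library known to work.
    library = sniffio.current_async_library()
    if library != "asyncio":
        raise ValueError(
            f"Asynchronous sniffing only works with the 'asyncio' library, got {library}"
        )
```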
I find it interesting how this is exactly the discussion we're having in elastic/elasticsearch-py#3103 (comment).