Skip to content

Commit ee52dcb

Browse files
committed
Support Trio on httpx backend
The main limitation is that sniffing isn't supported, as the way it's currently designed (starting a task in the background and never collecting it) is not compatible with structured concurrency.
1 parent 6da4c83 commit ee52dcb

File tree

8 files changed

+125
-61
lines changed

8 files changed

+125
-61
lines changed

elastic_transport/_async_transport.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import asyncio
1919
import logging
20+
import time
21+
import sniffio
2022
from typing import (
2123
Any,
2224
Awaitable,
@@ -169,6 +171,7 @@ def __init__(
169171
# time it's needed. Gets set within '_async_call()' which should
170172
# precede all logic within async calls.
171173
self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]
174+
self._async_library: str = None # type: ignore[assignment]
172175

173176
# AsyncTransport doesn't require a thread lock for
174177
# sniffing. Uses '_sniffing_task' instead.
@@ -258,7 +261,7 @@ async def perform_request( # type: ignore[override, return]
258261
node_failure = False
259262
last_response: Optional[TransportApiResponse] = None
260263
node: BaseAsyncNode = self.node_pool.get() # type: ignore[assignment]
261-
start_time = self._loop.time()
264+
start_time = time.monotonic()
262265
try:
263266
otel_span.set_node_metadata(
264267
node.host, node.port, node.base_url, target, method
@@ -277,7 +280,7 @@ async def perform_request( # type: ignore[override, return]
277280
node.base_url,
278281
target,
279282
resp.meta.status,
280-
self._loop.time() - start_time,
283+
time.monotonic() - start_time,
281284
)
282285
)
283286

@@ -300,7 +303,7 @@ async def perform_request( # type: ignore[override, return]
300303
node.base_url,
301304
target,
302305
"N/A",
303-
self._loop.time() - start_time,
306+
time.monotonic() - start_time,
304307
)
305308
)
306309

@@ -377,6 +380,10 @@ async def perform_request( # type: ignore[override, return]
377380
)
378381

379382
async def sniff(self, is_initial_sniff: bool = False) -> None: # type: ignore[override]
383+
if sniffio.current_async_library() != "asyncio":
384+
raise ValueError(
385+
f"Asynchronous sniffing only works with the 'asyncio' library, got {sniffio.current_async_library}"
386+
)
380387
await self._async_call()
381388
task = self._create_sniffing_task(is_initial_sniff)
382389

@@ -409,8 +416,7 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
409416
self._sniffing_task.result()
410417

411418
return (
412-
self._loop.time() - self._last_sniffed_at
413-
>= self._min_delay_between_sniffing
419+
time.monotonic() - self._last_sniffed_at >= self._min_delay_between_sniffing
414420
)
415421

416422
def _create_sniffing_task(
@@ -429,7 +435,7 @@ async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
429435
"""Implementation of the sniffing task"""
430436
previously_sniffed_at = self._last_sniffed_at
431437
try:
432-
self._last_sniffed_at = self._loop.time()
438+
self._last_sniffed_at = time.monotonic()
433439
options = SniffOptions(
434440
is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
435441
)
@@ -466,8 +472,13 @@ async def _async_call(self) -> None:
466472
because we're not guaranteed to be within an active asyncio event loop
467473
when __init__() is called.
468474
"""
469-
if self._loop is not None:
475+
if self._async_library is not None:
470476
return # Call at most once!
477+
478+
self._async_library = sniffio.current_async_library()
479+
if self._async_library != "asyncio":
480+
return
481+
471482
self._loop = asyncio.get_running_loop()
472483
if self._sniff_on_start:
473484
await self.sniff(True)

elastic_transport/_node/_http_httpx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,11 @@ async def perform_request( # type: ignore[override]
175175
body=body,
176176
exception=err,
177177
)
178-
raise err from None
178+
raise err from e
179179

180180
meta = ApiResponseMeta(
181181
resp.status_code,
182-
resp.http_version,
182+
resp.http_version.lstrip("HTTP/"),
183183
HttpHeaders(resp.headers),
184184
duration,
185185
self.config,

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
install_requires=[
5353
"urllib3>=1.26.2, <3",
5454
"certifi",
55+
"sniffio",
56+
"anyio",
5557
],
5658
python_requires=">=3.8",
5759
extras_require={

tests/async_/test_async_transport.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
import warnings
2525
from unittest import mock
2626

27+
import anyio
28+
import sniffio
29+
2730
import pytest
2831

2932
from elastic_transport import (
@@ -45,9 +48,11 @@
4548
from tests.conftest import AsyncDummyNode
4649

4750

48-
@pytest.mark.asyncio
51+
@pytest.mark.anyio
4952
async def test_async_transport_httpbin(httpbin_node_config, httpbin):
50-
t = AsyncTransport([httpbin_node_config], meta_header=False)
53+
t = AsyncTransport(
54+
[httpbin_node_config], meta_header=False, node_class=HttpxAsyncHttpNode
55+
)
5156
resp, data = await t.perform_request("GET", "/anything?key=value")
5257

5358
assert resp.status == 200
@@ -57,6 +62,8 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
5762

5863
data["headers"].pop("X-Amzn-Trace-Id", None)
5964
assert data["headers"] == {
65+
"Accept": "*/*",
66+
"Accept-Encoding": "gzip, deflate, br",
6067
"User-Agent": DEFAULT_USER_AGENT,
6168
"Connection": "keep-alive",
6269
"Host": f"{httpbin.host}:{httpbin.port}",
@@ -66,15 +73,15 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
6673
@pytest.mark.skipif(
6774
sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8"
6875
)
69-
@pytest.mark.asyncio
76+
@pytest.mark.anyio
7077
async def test_transport_close_node_pool():
7178
t = AsyncTransport([NodeConfig("http", "localhost", 443)])
7279
with mock.patch.object(t.node_pool.all()[0], "close") as node_close:
7380
await t.close()
7481
node_close.assert_called_with()
7582

7683

77-
@pytest.mark.asyncio
84+
@pytest.mark.anyio
7885
async def test_request_with_custom_user_agent_header():
7986
t = AsyncTransport(
8087
[NodeConfig("http", "localhost", 80)],
@@ -91,7 +98,7 @@ async def test_request_with_custom_user_agent_header():
9198
} == t.node_pool.get().calls[0][1]
9299

93100

94-
@pytest.mark.asyncio
101+
@pytest.mark.anyio
95102
async def test_body_gets_encoded_into_bytes():
96103
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
97104

@@ -105,7 +112,7 @@ async def test_body_gets_encoded_into_bytes():
105112
assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}'
106113

107114

108-
@pytest.mark.asyncio
115+
@pytest.mark.anyio
109116
async def test_body_bytes_get_passed_untouched():
110117
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
111118

@@ -131,7 +138,7 @@ def test_kwargs_passed_on_to_node_pool():
131138
assert dt is t.node_pool.max_dead_node_backoff
132139

133140

134-
@pytest.mark.asyncio
141+
@pytest.mark.anyio
135142
async def test_request_will_fail_after_x_retries():
136143
t = AsyncTransport(
137144
[
@@ -154,7 +161,7 @@ async def test_request_will_fail_after_x_retries():
154161

155162

156163
@pytest.mark.parametrize("retry_on_timeout", [True, False])
157-
@pytest.mark.asyncio
164+
@pytest.mark.anyio
158165
async def test_retry_on_timeout(retry_on_timeout):
159166
t = AsyncTransport(
160167
[
@@ -189,7 +196,7 @@ async def test_retry_on_timeout(retry_on_timeout):
189196
assert len(e.value.errors) == 0
190197

191198

192-
@pytest.mark.asyncio
199+
@pytest.mark.anyio
193200
async def test_retry_on_status():
194201
t = AsyncTransport(
195202
[
@@ -233,7 +240,7 @@ async def test_retry_on_status():
233240
]
234241

235242

236-
@pytest.mark.asyncio
243+
@pytest.mark.anyio
237244
async def test_failed_connection_will_be_marked_as_dead():
238245
t = AsyncTransport(
239246
[
@@ -262,7 +269,7 @@ async def test_failed_connection_will_be_marked_as_dead():
262269
assert all(isinstance(error, ConnectionError) for error in e.value.errors)
263270

264271

265-
@pytest.mark.asyncio
272+
@pytest.mark.anyio
266273
async def test_resurrected_connection_will_be_marked_as_live_on_success():
267274
for method in ("GET", "HEAD"):
268275
t = AsyncTransport(
@@ -283,7 +290,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success():
283290
assert 1 == len(t.node_pool._dead_nodes.queue)
284291

285292

286-
@pytest.mark.asyncio
293+
@pytest.mark.anyio
287294
async def test_mark_dead_error_doesnt_raise():
288295
t = AsyncTransport(
289296
[
@@ -303,7 +310,7 @@ async def test_mark_dead_error_doesnt_raise():
303310
mark_dead.assert_called_with(bad_node)
304311

305312

306-
@pytest.mark.asyncio
313+
@pytest.mark.anyio
307314
async def test_node_class_as_string():
308315
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp")
309316
assert isinstance(t.node_pool.get(), AiohttpHttpNode)
@@ -320,7 +327,7 @@ async def test_node_class_as_string():
320327

321328

322329
@pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)])
323-
@pytest.mark.asyncio
330+
@pytest.mark.anyio
324331
async def test_head_response_true(status, boolean):
325332
t = AsyncTransport(
326333
[NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})],
@@ -331,7 +338,7 @@ async def test_head_response_true(status, boolean):
331338
assert data is None
332339

333340

334-
@pytest.mark.asyncio
341+
@pytest.mark.anyio
335342
async def test_head_response_false():
336343
t = AsyncTransport(
337344
[NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})],
@@ -353,7 +360,7 @@ async def test_head_response_false():
353360
(HttpxAsyncHttpNode, "hx"),
354361
],
355362
)
356-
@pytest.mark.asyncio
363+
@pytest.mark.anyio
357364
async def test_transport_client_meta_node_class(node_class, client_short_name):
358365
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class)
359366
assert (
@@ -366,7 +373,7 @@ async def test_transport_client_meta_node_class(node_class, client_short_name):
366373
)
367374

368375

369-
@pytest.mark.asyncio
376+
@pytest.mark.anyio
370377
async def test_transport_default_client_meta_node_class():
371378
# Defaults to aiohttp
372379
t = AsyncTransport(
@@ -635,7 +642,7 @@ async def test_sniff_on_start_no_results_errors(sniff_callback):
635642

636643

637644
@pytest.mark.parametrize("pool_size", [1, 8])
638-
@pytest.mark.asyncio
645+
@pytest.mark.anyio
639646
async def test_multiple_tasks_test(pool_size):
640647
node_configs = [
641648
NodeConfig("http", "localhost", 80),
@@ -648,34 +655,45 @@ async def sniff_callback(*_):
648655
await asyncio.sleep(random.random())
649656
return node_configs
650657

658+
kwargs = {}
659+
if sniffio.current_async_library() == "asyncio":
660+
kwargs = {
661+
"sniff_on_start": True,
662+
"sniff_before_requests": True,
663+
"sniff_on_node_failure": True,
664+
"sniff_callback": sniff_callback,
665+
}
666+
667+
print(kwargs)
668+
651669
t = AsyncTransport(
652670
node_configs,
653671
retry_on_status=[500],
654672
max_retries=5,
655673
node_class=AsyncDummyNode,
656-
sniff_on_start=True,
657-
sniff_before_requests=True,
658-
sniff_on_node_failure=True,
659-
sniff_callback=sniff_callback,
674+
**kwargs,
660675
)
661676

662-
loop = asyncio.get_running_loop()
663-
start = loop.time()
677+
start = time.monotonic()
678+
679+
successful_requests = 0
664680

665681
async def run_requests():
666-
successful_requests = 0
667-
while loop.time() - start < 2:
682+
nonlocal successful_requests
683+
while time.monotonic() - start < 2:
668684
await t.perform_request("GET", "/")
669685
successful_requests += 1
670686
return successful_requests
671687

672-
tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)]
673-
assert sum([await task for task in tasks]) >= 1000
688+
async with anyio.create_task_group() as tg:
689+
for _ in range(pool_size * 2):
690+
tg.start_soon(run_requests)
691+
assert successful_requests >= 1000
674692

675693

676-
@pytest.mark.asyncio
694+
@pytest.mark.anyio
677695
async def test_httpbin(httpbin_node_config):
678-
t = AsyncTransport([httpbin_node_config])
696+
t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
679697
resp = await t.perform_request("GET", "/anything")
680698
assert resp.meta.status == 200
681699
assert isinstance(resp.body, dict)

0 commit comments

Comments
 (0)