Skip to content

Commit 9cb08a5

Browse files
pquentingithub-actions[bot]
authored andcommitted
Support Trio with httpx (#263)
* Support Trio on httpx backend The main limitation is that sniffing isn't supported, as the way it's currently designed (starting a task in the background and never collecting it) is not compatible with structured concurrency. * Fix lint * Add trio to develop dependencies * Remove anyio from runtime dependencies * Clarify asyncio/trio support * Link to HTTPX issue in badssl tests * Remove debug print * Reverse asyncio/trio check per review feedback (cherry picked from commit 33c0af3)
1 parent 6da4c83 commit 9cb08a5

File tree

9 files changed

+136
-62
lines changed

9 files changed

+136
-62
lines changed

elastic_transport/_async_transport.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import asyncio
1919
import logging
20+
import time
2021
from typing import (
2122
Any,
2223
Awaitable,
@@ -30,6 +31,8 @@
3031
Union,
3132
)
3233

34+
import sniffio
35+
3336
from ._compat import await_if_coro
3437
from ._exceptions import (
3538
ConnectionError,
@@ -169,6 +172,7 @@ def __init__(
169172
# time it's needed. Gets set within '_async_call()' which should
170173
# precede all logic within async calls.
171174
self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]
175+
self._async_library: str = None # type: ignore[assignment]
172176

173177
# AsyncTransport doesn't require a thread lock for
174178
# sniffing. Uses '_sniffing_task' instead.
@@ -258,7 +262,7 @@ async def perform_request( # type: ignore[override, return]
258262
node_failure = False
259263
last_response: Optional[TransportApiResponse] = None
260264
node: BaseAsyncNode = self.node_pool.get() # type: ignore[assignment]
261-
start_time = self._loop.time()
265+
start_time = time.monotonic()
262266
try:
263267
otel_span.set_node_metadata(
264268
node.host, node.port, node.base_url, target, method
@@ -277,7 +281,7 @@ async def perform_request( # type: ignore[override, return]
277281
node.base_url,
278282
target,
279283
resp.meta.status,
280-
self._loop.time() - start_time,
284+
time.monotonic() - start_time,
281285
)
282286
)
283287

@@ -300,7 +304,7 @@ async def perform_request( # type: ignore[override, return]
300304
node.base_url,
301305
target,
302306
"N/A",
303-
self._loop.time() - start_time,
307+
time.monotonic() - start_time,
304308
)
305309
)
306310

@@ -377,6 +381,10 @@ async def perform_request( # type: ignore[override, return]
377381
)
378382

379383
async def sniff(self, is_initial_sniff: bool = False) -> None: # type: ignore[override]
384+
if sniffio.current_async_library() == "trio":
385+
raise ValueError(
386+
f"Asynchronous sniffing is not supported with the 'trio' library, got {sniffio.current_async_library}"
387+
)
380388
await self._async_call()
381389
task = self._create_sniffing_task(is_initial_sniff)
382390

@@ -409,8 +417,7 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
409417
self._sniffing_task.result()
410418

411419
return (
412-
self._loop.time() - self._last_sniffed_at
413-
>= self._min_delay_between_sniffing
420+
time.monotonic() - self._last_sniffed_at >= self._min_delay_between_sniffing
414421
)
415422

416423
def _create_sniffing_task(
@@ -429,7 +436,7 @@ async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
429436
"""Implementation of the sniffing task"""
430437
previously_sniffed_at = self._last_sniffed_at
431438
try:
432-
self._last_sniffed_at = self._loop.time()
439+
self._last_sniffed_at = time.monotonic()
433440
options = SniffOptions(
434441
is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
435442
)
@@ -466,8 +473,13 @@ async def _async_call(self) -> None:
466473
because we're not guaranteed to be within an active asyncio event loop
467474
when __init__() is called.
468475
"""
469-
if self._loop is not None:
476+
if self._async_library is not None:
470477
return # Call at most once!
478+
479+
self._async_library = sniffio.current_async_library()
480+
if self._async_library == "trio":
481+
return
482+
471483
self._loop = asyncio.get_running_loop()
472484
if self._sniff_on_start:
473485
await self.sniff(True)

elastic_transport/_node/_http_aiohttp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ class RequestKwarg(TypedDict, total=False):
7272

7373

7474
class AiohttpHttpNode(BaseAsyncNode):
75-
"""Default asynchronous node class using the ``aiohttp`` library via HTTP"""
75+
"""Default asynchronous node class using the ``aiohttp`` library via HTTP.
76+
77+
Supports asyncio.
78+
"""
7679

7780
_CLIENT_META_HTTP_CLIENT = ("ai", _AIOHTTP_META_VERSION)
7881

elastic_transport/_node/_http_httpx.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646

4747

4848
class HttpxAsyncHttpNode(BaseAsyncNode):
49+
"""
50+
Async HTTP node using httpx. Supports both Trio and asyncio.
51+
"""
52+
4953
_CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION)
5054

5155
def __init__(self, config: NodeConfig):
@@ -175,11 +179,11 @@ async def perform_request( # type: ignore[override]
175179
body=body,
176180
exception=err,
177181
)
178-
raise err from None
182+
raise err from e
179183

180184
meta = ApiResponseMeta(
181185
resp.status_code,
182-
resp.http_version,
186+
resp.http_version.lstrip("HTTP/"),
183187
HttpHeaders(resp.headers),
184188
duration,
185189
self.config,

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
install_requires=[
5353
"urllib3>=1.26.2, <3",
5454
"certifi",
55+
"sniffio",
5556
],
5657
python_requires=">=3.8",
5758
extras_require={
@@ -70,6 +71,8 @@
7071
"opentelemetry-api",
7172
"opentelemetry-sdk",
7273
"orjson",
74+
"anyio",
75+
"trio",
7376
# Override Read the Docs default (sphinx<2)
7477
"sphinx>2",
7578
"furo",

tests/async_/test_async_transport.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import warnings
2525
from unittest import mock
2626

27+
import anyio
2728
import pytest
29+
import sniffio
2830

2931
from elastic_transport import (
3032
AiohttpHttpNode,
@@ -45,9 +47,11 @@
4547
from tests.conftest import AsyncDummyNode
4648

4749

48-
@pytest.mark.asyncio
50+
@pytest.mark.anyio
4951
async def test_async_transport_httpbin(httpbin_node_config, httpbin):
50-
t = AsyncTransport([httpbin_node_config], meta_header=False)
52+
t = AsyncTransport(
53+
[httpbin_node_config], meta_header=False, node_class=HttpxAsyncHttpNode
54+
)
5155
resp, data = await t.perform_request("GET", "/anything?key=value")
5256

5357
assert resp.status == 200
@@ -57,6 +61,8 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
5761

5862
data["headers"].pop("X-Amzn-Trace-Id", None)
5963
assert data["headers"] == {
64+
"Accept": "*/*",
65+
"Accept-Encoding": "gzip, deflate, br",
6066
"User-Agent": DEFAULT_USER_AGENT,
6167
"Connection": "keep-alive",
6268
"Host": f"{httpbin.host}:{httpbin.port}",
@@ -66,15 +72,15 @@ async def test_async_transport_httpbin(httpbin_node_config, httpbin):
6672
@pytest.mark.skipif(
6773
sys.version_info < (3, 8), reason="Mock didn't support async before Python 3.8"
6874
)
69-
@pytest.mark.asyncio
75+
@pytest.mark.anyio
7076
async def test_transport_close_node_pool():
7177
t = AsyncTransport([NodeConfig("http", "localhost", 443)])
7278
with mock.patch.object(t.node_pool.all()[0], "close") as node_close:
7379
await t.close()
7480
node_close.assert_called_with()
7581

7682

77-
@pytest.mark.asyncio
83+
@pytest.mark.anyio
7884
async def test_request_with_custom_user_agent_header():
7985
t = AsyncTransport(
8086
[NodeConfig("http", "localhost", 80)],
@@ -91,7 +97,7 @@ async def test_request_with_custom_user_agent_header():
9197
} == t.node_pool.get().calls[0][1]
9298

9399

94-
@pytest.mark.asyncio
100+
@pytest.mark.anyio
95101
async def test_body_gets_encoded_into_bytes():
96102
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
97103

@@ -105,7 +111,7 @@ async def test_body_gets_encoded_into_bytes():
105111
assert kwargs["body"] == b'{"key":"\xe4\xbd\xa0\xe5\xa5\xbd"}'
106112

107113

108-
@pytest.mark.asyncio
114+
@pytest.mark.anyio
109115
async def test_body_bytes_get_passed_untouched():
110116
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=AsyncDummyNode)
111117

@@ -131,7 +137,7 @@ def test_kwargs_passed_on_to_node_pool():
131137
assert dt is t.node_pool.max_dead_node_backoff
132138

133139

134-
@pytest.mark.asyncio
140+
@pytest.mark.anyio
135141
async def test_request_will_fail_after_x_retries():
136142
t = AsyncTransport(
137143
[
@@ -154,7 +160,7 @@ async def test_request_will_fail_after_x_retries():
154160

155161

156162
@pytest.mark.parametrize("retry_on_timeout", [True, False])
157-
@pytest.mark.asyncio
163+
@pytest.mark.anyio
158164
async def test_retry_on_timeout(retry_on_timeout):
159165
t = AsyncTransport(
160166
[
@@ -189,7 +195,7 @@ async def test_retry_on_timeout(retry_on_timeout):
189195
assert len(e.value.errors) == 0
190196

191197

192-
@pytest.mark.asyncio
198+
@pytest.mark.anyio
193199
async def test_retry_on_status():
194200
t = AsyncTransport(
195201
[
@@ -233,7 +239,7 @@ async def test_retry_on_status():
233239
]
234240

235241

236-
@pytest.mark.asyncio
242+
@pytest.mark.anyio
237243
async def test_failed_connection_will_be_marked_as_dead():
238244
t = AsyncTransport(
239245
[
@@ -262,7 +268,7 @@ async def test_failed_connection_will_be_marked_as_dead():
262268
assert all(isinstance(error, ConnectionError) for error in e.value.errors)
263269

264270

265-
@pytest.mark.asyncio
271+
@pytest.mark.anyio
266272
async def test_resurrected_connection_will_be_marked_as_live_on_success():
267273
for method in ("GET", "HEAD"):
268274
t = AsyncTransport(
@@ -283,7 +289,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success():
283289
assert 1 == len(t.node_pool._dead_nodes.queue)
284290

285291

286-
@pytest.mark.asyncio
292+
@pytest.mark.anyio
287293
async def test_mark_dead_error_doesnt_raise():
288294
t = AsyncTransport(
289295
[
@@ -303,7 +309,7 @@ async def test_mark_dead_error_doesnt_raise():
303309
mark_dead.assert_called_with(bad_node)
304310

305311

306-
@pytest.mark.asyncio
312+
@pytest.mark.anyio
307313
async def test_node_class_as_string():
308314
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="aiohttp")
309315
assert isinstance(t.node_pool.get(), AiohttpHttpNode)
@@ -320,7 +326,7 @@ async def test_node_class_as_string():
320326

321327

322328
@pytest.mark.parametrize(["status", "boolean"], [(200, True), (299, True)])
323-
@pytest.mark.asyncio
329+
@pytest.mark.anyio
324330
async def test_head_response_true(status, boolean):
325331
t = AsyncTransport(
326332
[NodeConfig("http", "localhost", 80, _extras={"status": status, "body": b""})],
@@ -331,7 +337,7 @@ async def test_head_response_true(status, boolean):
331337
assert data is None
332338

333339

334-
@pytest.mark.asyncio
340+
@pytest.mark.anyio
335341
async def test_head_response_false():
336342
t = AsyncTransport(
337343
[NodeConfig("http", "localhost", 80, _extras={"status": 404, "body": b""})],
@@ -353,7 +359,7 @@ async def test_head_response_false():
353359
(HttpxAsyncHttpNode, "hx"),
354360
],
355361
)
356-
@pytest.mark.asyncio
362+
@pytest.mark.anyio
357363
async def test_transport_client_meta_node_class(node_class, client_short_name):
358364
t = AsyncTransport([NodeConfig("http", "localhost", 80)], node_class=node_class)
359365
assert (
@@ -366,7 +372,7 @@ async def test_transport_client_meta_node_class(node_class, client_short_name):
366372
)
367373

368374

369-
@pytest.mark.asyncio
375+
@pytest.mark.anyio
370376
async def test_transport_default_client_meta_node_class():
371377
# Defaults to aiohttp
372378
t = AsyncTransport(
@@ -635,7 +641,7 @@ async def test_sniff_on_start_no_results_errors(sniff_callback):
635641

636642

637643
@pytest.mark.parametrize("pool_size", [1, 8])
638-
@pytest.mark.asyncio
644+
@pytest.mark.anyio
639645
async def test_multiple_tasks_test(pool_size):
640646
node_configs = [
641647
NodeConfig("http", "localhost", 80),
@@ -648,34 +654,43 @@ async def sniff_callback(*_):
648654
await asyncio.sleep(random.random())
649655
return node_configs
650656

657+
kwargs = {}
658+
if sniffio.current_async_library() == "asyncio":
659+
kwargs = {
660+
"sniff_on_start": True,
661+
"sniff_before_requests": True,
662+
"sniff_on_node_failure": True,
663+
"sniff_callback": sniff_callback,
664+
}
665+
651666
t = AsyncTransport(
652667
node_configs,
653668
retry_on_status=[500],
654669
max_retries=5,
655670
node_class=AsyncDummyNode,
656-
sniff_on_start=True,
657-
sniff_before_requests=True,
658-
sniff_on_node_failure=True,
659-
sniff_callback=sniff_callback,
671+
**kwargs,
660672
)
661673

662-
loop = asyncio.get_running_loop()
663-
start = loop.time()
674+
start = time.monotonic()
675+
676+
successful_requests = 0
664677

665678
async def run_requests():
666-
successful_requests = 0
667-
while loop.time() - start < 2:
679+
nonlocal successful_requests
680+
while time.monotonic() - start < 2:
668681
await t.perform_request("GET", "/")
669682
successful_requests += 1
670683
return successful_requests
671684

672-
tasks = [loop.create_task(run_requests()) for _ in range(pool_size * 2)]
673-
assert sum([await task for task in tasks]) >= 1000
685+
async with anyio.create_task_group() as tg:
686+
for _ in range(pool_size * 2):
687+
tg.start_soon(run_requests)
688+
assert successful_requests >= 1000
674689

675690

676-
@pytest.mark.asyncio
691+
@pytest.mark.anyio
677692
async def test_httpbin(httpbin_node_config):
678-
t = AsyncTransport([httpbin_node_config])
693+
t = AsyncTransport([httpbin_node_config], node_class=HttpxAsyncHttpNode)
679694
resp = await t.perform_request("GET", "/anything")
680695
assert resp.meta.status == 200
681696
assert isinstance(resp.body, dict)

0 commit comments

Comments
 (0)