Skip to content

Commit a8ef49f

Browse files
authored
Migrating httpx.AsyncClient to httpx_aiohttp.HttpxAiohttpClient (#287)
1 parent 2178e44 commit a8ef49f

File tree

7 files changed

+77
-16
lines changed

7 files changed

+77
-16
lines changed

packages/hotpotqa/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"datasets>=2.15,<4", # Lower pin for https://github.com/huggingface/datasets/pull/6404, upper pin for https://huggingface.co/datasets/hotpotqa/hotpot_qa/discussions/8
2323
"fhaviary",
2424
"httpx",
25+
"httpx-aiohttp",
2526
"pydantic~=2.0",
2627
"tenacity",
2728
]

packages/hotpotqa/src/aviary/envs/hotpotqa/env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from uuid import UUID
2525

2626
import httpx
27+
import httpx_aiohttp
2728
from bs4 import BeautifulSoup
2829
from datasets import load_dataset
2930
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -424,7 +425,7 @@ async def search(self, entity: str) -> str:
424425
response_text = self.wiki_cache[search_entity]
425426
except KeyError:
426427
# follow_redirects=True because wikipedia frequently redirects to the correct page
427-
async with httpx.AsyncClient(
428+
async with httpx_aiohttp.HttpxAiohttpClient(
428429
follow_redirects=True, proxy=self.proxy
429430
) as client:
430431
response = await fetch_with_retry(

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ classifiers = [
2424
dependencies = [
2525
"docstring_parser>=0.16", # Pin for description addition
2626
"httpx",
27+
"httpx-aiohttp",
2728
"pydantic~=2.0",
2829
]
2930
description = "Gymnasium framework for training language model agents on constructive tasks"
@@ -46,7 +47,7 @@ dev = [
4647
"fhaviary[image,llm,server,typing,xml]",
4748
"ipython>=8", # Pin to keep recent
4849
"jupyter>=1.0.0", # For running notebooks
49-
"litellm>=1.65.5,<1.71", # Lower pin for sending tool schemae title in completions, upper pin for VCR cassette breaks (https://github.com/BerriAI/litellm/issues/11724)
50+
"litellm>=1.71", # Lower pin for aiohttp transport adoption
5051
"mypy>=1.8", # Pin for mutable-override
5152
"numpy>=1", # Pin to keep recent
5253
"pre-commit>=3.4", # Pin to keep recent

src/aviary/env_client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Generic, TypeVar, cast
55

66
import httpx
7+
import httpx_aiohttp
78
from pydantic import BaseModel, Field
89

910
from aviary.env import Environment, TaskDataset
@@ -14,7 +15,9 @@
1415

1516
# Not sure why, but mypy complains if we use the TEnvState in aviary.env, so redefine here
1617
TEnvState = TypeVar("TEnvState")
17-
TClient = TypeVar("TClient", httpx.Client, httpx.AsyncClient)
18+
TClient = TypeVar(
19+
"TClient", httpx.Client, httpx.AsyncClient, httpx_aiohttp.HttpxAiohttpClient
20+
)
1821

1922

2023
class EnvironmentClient(Environment[TEnvState], ABC, Generic[TEnvState]):
@@ -35,7 +38,7 @@ def __init__(
3538
self._api_key = api_key
3639

3740
async def _post(self, url: str, json: Mapping[str, Any]) -> httpx.Response:
38-
async with httpx.AsyncClient() as client:
41+
async with httpx_aiohttp.HttpxAiohttpClient() as client:
3942
headers = httpx.Headers(self._request_headers)
4043
if self._api_key:
4144
headers["X-API-Key"] = self._api_key
@@ -142,7 +145,7 @@ def __init__(
142145

143146
def _get_http_client(
144147
self,
145-
client_class: type[TClient] = httpx.AsyncClient, # type: ignore[assignment]
148+
client_class: type[TClient] = httpx_aiohttp.HttpxAiohttpClient, # type: ignore[assignment]
146149
) -> TClient:
147150
headers = {}
148151
if self.api_key:

tests/conftest.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from collections.abc import AsyncIterator
12
from typing import TYPE_CHECKING, Any
23
from urllib.parse import parse_qs, urlencode, urlparse
34

5+
import httpx_aiohttp
6+
import litellm.llms.custom_httpx.aiohttp_transport
47
import pytest
5-
import vcr.request
8+
import vcr.stubs.httpx_stubs
69

710
from . import CASSETTES_DIR
811

912
if TYPE_CHECKING:
10-
import vcr.request
13+
import vcr.request # noqa: TC004
1114

1215
from aviary.core import DummyEnv
1316

@@ -56,3 +59,38 @@ def fixture_vcr_config() -> dict[str, Any]:
5659
"allow_playback_repeats": True,
5760
"cassette_library_dir": str(CASSETTES_DIR),
5861
}
62+
63+
64+
class PreReadCompatibleAiohttpResponseStream(
65+
httpx_aiohttp.transport.AiohttpResponseStream
66+
):
67+
"""aiohttp-backed response stream that works if the response was pre-read."""
68+
69+
async def __aiter__(self) -> AsyncIterator[bytes]:
70+
with httpx_aiohttp.transport.map_aiohttp_exceptions():
71+
if self._aiohttp_response._body is not None:
72+
# Happens if some intermediary called `await _aiohttp_response.read()`
73+
# TODO: take into account chunk size
74+
yield self._aiohttp_response._body
75+
else:
76+
async for chunk in self._aiohttp_response.content.iter_chunked(
77+
self.CHUNK_SIZE
78+
):
79+
yield chunk
80+
81+
82+
async def _async_vcr_send(cassette, real_send, *args, **kwargs): # noqa: ARG001
83+
"""VCR send that only sends, not possibly recording or playing back responses."""
84+
return await real_send(*args, **kwargs)
85+
86+
87+
# Permanently patch the original response stream,
88+
# to work around https://github.com/karpetrosyan/httpx-aiohttp/issues/23
89+
# and https://github.com/BerriAI/litellm/issues/11724
90+
httpx_aiohttp.transport.AiohttpResponseStream = ( # type: ignore[misc]
91+
litellm.llms.custom_httpx.aiohttp_transport.AiohttpResponseStream # type: ignore[misc]
92+
) = PreReadCompatibleAiohttpResponseStream # type: ignore[assignment]
93+
94+
# Permanently patch vcrpy's async VCR recording functionality,
95+
# to work around https://github.com/kevin1024/vcrpy/issues/944
96+
vcr.stubs.httpx_stubs._async_vcr_send = _async_vcr_send

tests/test_envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,9 @@ async def test_dummyenv_using_empty_params(self, dummy_env: DummyEnv) -> None:
562562
def server_async_client() -> AsyncClient:
563563
dataset = TaskDataset.from_name("dummy")
564564
server = TaskDatasetServer[DummyEnv](dataset)
565+
# Use httpx.AsyncClient over httpx_aiohttp.HttpxAiohttpClient in tests here,
566+
# as httpx_aiohttp.AiohttpTransport doesn't support an app argument
567+
# as of httpx-aiohttp==0.1.8
565568
return AsyncClient(transport=ASGITransport(app=server.app), base_url="http://test")
566569

567570

uv.lock

Lines changed: 23 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)