Skip to content

Commit 81ba376

Browse files
authored
fix: initialize async http client with async method (#5296)
* fix: initialize async http client with async method Signed-off-by: Frost Ming <me@frostming.com> * fix types Signed-off-by: Frost Ming <me@frostming.com> * fix: missing attr Signed-off-by: Frost Ming <me@frostming.com> * fix: slots class Signed-off-by: Frost Ming <me@frostming.com> * Merge branch 'main' into fix/async-init
1 parent 0460e57 commit 81ba376

File tree

4 files changed

+143
-64
lines changed

4 files changed

+143
-64
lines changed

docs/source/build-with-bentoml/clients.rst

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,9 @@ After you start the ``Summarization`` Service, you can create the following clie
5656
import bentoml
5757
5858
async def async_client_operation():
59-
client = bentoml.AsyncHTTPClient('http://localhost:3000')
60-
summarized_text: str = await client.summarize(text="Your long text to summarize")
61-
print(summarized_text)
62-
63-
# Close the client to release resources
64-
await client.close()
59+
async with bentoml.AsyncHTTPClient('http://localhost:3000') as client:
60+
summarized_text: str = await client.summarize(text="Your long text to summarize")
61+
print(summarized_text)
6562
6663
asyncio.run(async_client_operation())
6764
@@ -357,12 +354,10 @@ You can add streaming logic to a BentoML client, which is especially useful when
357354
358355
import bentoml
359356
360-
client = bentoml.AsyncHTTPClient("http://localhost:3000")
361-
async for data_chunk in client.stream_data():
362-
# Process each chunk of data as it arrives
363-
await process_data_async(data_chunk)
364-
365-
await client.close()
357+
async with bentoml.AsyncHTTPClient("http://localhost:3000") as client:
358+
async for data_chunk in client.stream_data():
359+
# Process each chunk of data as it arrives
360+
await process_data_async(data_chunk)
366361
367362
async def process_data_async(data_chunk):
368363
# Add processing logic

src/_bentoml_impl/client/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class ClientEndpoint:
3636
class AbstractClient(abc.ABC):
3737
endpoints: dict[str, ClientEndpoint]
3838

39-
def __init__(self) -> None:
39+
def _setup_endpoints(self) -> None:
40+
self._setup_done = True
4041
for name in self.endpoints:
4142
if name == "__call__":
4243
# __call__ must be set on the class

src/_bentoml_impl/client/http.py

Lines changed: 130 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -45,35 +45,39 @@
4545

4646
from ..serde import Serde
4747

48-
T = t.TypeVar("T", bound="HTTPClient[t.Any]")
49-
A = t.TypeVar("A")
48+
T = t.TypeVar("T")
49+
AnyClient = t.TypeVar("AnyClient", httpx.Client, httpx.AsyncClient)
50+
5051
C = t.TypeVar("C", httpx.Client, httpx.AsyncClient)
51-
AnyClient = t.TypeVar("AnyClient", httpx.Client, httpx.AsyncClient)
52+
5253
logger = logging.getLogger("bentoml.io")
5354
MAX_RETRIES = 3
5455

5556

56-
def to_async_iterable(iterable: t.Iterable[A]) -> t.AsyncIterable[A]:
57-
async def _gen() -> t.AsyncIterator[A]:
57+
def to_async_iterable(iterable: t.Iterable[T]) -> t.AsyncIterable[T]:
58+
async def _gen() -> t.AsyncIterator[T]:
5859
for item in iterable:
5960
yield item
6061

6162
return _gen()
6263

6364

64-
@attr.define
65+
@attr.define(slots=False)
6566
class HTTPClient(AbstractClient, t.Generic[C]):
66-
client_cls: t.ClassVar[type[httpx.Client] | type[httpx.AsyncClient]]
67+
client_cls: t.ClassVar[type[C]] # type: ignore
6768

6869
url: str
6970
endpoints: dict[str, ClientEndpoint] = attr.field(factory=dict)
7071
media_type: str = "application/json"
7172
timeout: float = 30
7273
default_headers: dict[str, str] = attr.field(factory=dict)
7374
app: ASGIApp | None = None
75+
server_ready_timeout: float | None = None
76+
service: Service[t.Any] | None = None
7477

7578
_opened_files: list[io.BufferedReader] = attr.field(init=False, factory=list)
7679
_temp_dir: tempfile.TemporaryDirectory[str] = attr.field(init=False)
80+
_setup_done: bool = attr.field(init=False, default=False)
7781

7882
@staticmethod
7983
def _make_client(
@@ -174,30 +178,9 @@ def __init__(
174178
default_headers=default_headers,
175179
timeout=timeout,
176180
app=app,
181+
server_ready_timeout=server_ready_timeout,
182+
service=service,
177183
)
178-
if app is None and (server_ready_timeout is None or server_ready_timeout > 0):
179-
self.wait_until_server_ready(server_ready_timeout)
180-
if service is None:
181-
schema_url = urljoin(url, "/schema.json")
182-
183-
with self._make_client(
184-
httpx.Client, url, default_headers, timeout, app=app
185-
) as client:
186-
resp = client.get("/schema.json")
187-
188-
if resp.is_error:
189-
raise BentoMLException(f"Failed to fetch schema from {schema_url}")
190-
for route in resp.json()["routes"]:
191-
self.endpoints[route["name"]] = ClientEndpoint(
192-
name=route["name"],
193-
route=route["route"],
194-
input=route["input"],
195-
output=route["output"],
196-
doc=route.get("doc"),
197-
stream_output=route["output"].get("is_stream", False),
198-
is_task=route.get("is_task", False),
199-
)
200-
super().__init__()
201184

202185
@cached_property
203186
def client(self) -> C:
@@ -326,22 +309,6 @@ def _build_request(
326309
headers=headers,
327310
)
328311

329-
def wait_until_server_ready(self, timeout: int | None = None) -> None:
330-
if timeout is None:
331-
timeout = self.timeout
332-
with self._make_client(
333-
httpx.Client, self.url, self.default_headers, timeout
334-
) as client:
335-
start = time.monotonic()
336-
while time.monotonic() - start < timeout:
337-
try:
338-
resp = client.get("/readyz")
339-
if resp.status_code == 200:
340-
return
341-
except (httpx.TimeoutException, httpx.ConnectError):
342-
pass
343-
raise ServiceUnavailable(f"Server is not ready after {timeout} seconds")
344-
345312
def _get_file(self, value: t.Any) -> str | tuple[str, t.IO[bytes], str | None]:
346313
if isinstance(value, str) and not is_http_url(value):
347314
value = pathlib.Path(value)
@@ -457,7 +424,69 @@ class SyncHTTPClient(HTTPClient[httpx.Client]):
457424

458425
client_cls = httpx.Client
459426

460-
def __enter__(self: T) -> T:
427+
def __init__(
    self,
    url: str,
    *,
    media_type: str = "application/json",
    service: Service[t.Any] | None = None,
    server_ready_timeout: float | None = None,
    token: str | None = None,
    timeout: float = 30,
    app: ASGIApp | None = None,
):
    """Create the synchronous client and complete its setup right away.

    Every argument is handed unchanged to the parent initializer; afterwards
    ``self._setup()`` is invoked before the constructor returns.
    """
    # All options are passed by keyword, so their order here is irrelevant.
    super().__init__(
        url,
        app=app,
        service=service,
        token=token,
        media_type=media_type,
        timeout=timeout,
        server_ready_timeout=server_ready_timeout,
    )
    # Setup runs eagerly as the last step of construction.
    self._setup()
448+
449+
def _setup(self) -> None:
450+
if self._setup_done:
451+
return
452+
453+
if self.app is None and (
454+
self.server_ready_timeout is None or self.server_ready_timeout > 0
455+
):
456+
self.wait_until_server_ready(self.server_ready_timeout)
457+
if self.service is None:
458+
schema_url = urljoin(self.url, "/schema.json")
459+
460+
resp = self.client.get("/schema.json")
461+
462+
if resp.is_error:
463+
raise BentoMLException(f"Failed to fetch schema from {schema_url}")
464+
for route in resp.json()["routes"]:
465+
self.endpoints[route["name"]] = ClientEndpoint(
466+
name=route["name"],
467+
route=route["route"],
468+
input=route["input"],
469+
output=route["output"],
470+
doc=route.get("doc"),
471+
stream_output=route["output"].get("is_stream", False),
472+
is_task=route.get("is_task", False),
473+
)
474+
self._setup_endpoints()
475+
476+
def wait_until_server_ready(self, timeout: float | None = None) -> None:
477+
if timeout is None:
478+
timeout = self.timeout
479+
start = time.monotonic()
480+
while time.monotonic() - start < timeout:
481+
try:
482+
resp = self.client.get("/readyz")
483+
if resp.status_code == 200:
484+
return
485+
except (httpx.TimeoutException, httpx.ConnectError):
486+
pass
487+
raise ServiceUnavailable(f"Server is not ready after {timeout} seconds")
488+
489+
def __enter__(self) -> t.Self:
461490
return self
462491

463492
def __exit__(self, exc_type: t.Any, exc: t.Any, tb: t.Any) -> None:
@@ -470,7 +499,7 @@ def is_ready(self, timeout: int | None = None) -> bool:
470499
)
471500
return resp.status_code == 200
472501
except httpx.TimeoutException:
473-
logger.warn("Timed out waiting for runner to be ready")
502+
logger.warning("Timed out waiting for runner to be ready")
474503
return False
475504

476505
def close(self) -> None:
@@ -629,14 +658,54 @@ class AsyncHTTPClient(HTTPClient[httpx.AsyncClient]):
629658

630659
client_cls = httpx.AsyncClient
631660

661+
async def _setup(self) -> None:
662+
if self._setup_done:
663+
return
664+
665+
if self.app is None and (
666+
self.server_ready_timeout is None or self.server_ready_timeout > 0
667+
):
668+
await self.wait_until_server_ready(self.server_ready_timeout)
669+
if self.service is None:
670+
schema_url = urljoin(self.url, "/schema.json")
671+
672+
resp = await self.client.get("/schema.json")
673+
674+
if resp.is_error:
675+
raise BentoMLException(f"Failed to fetch schema from {schema_url}")
676+
for route in resp.json()["routes"]:
677+
self.endpoints[route["name"]] = ClientEndpoint(
678+
name=route["name"],
679+
route=route["route"],
680+
input=route["input"],
681+
output=route["output"],
682+
doc=route.get("doc"),
683+
stream_output=route["output"].get("is_stream", False),
684+
is_task=route.get("is_task", False),
685+
)
686+
self._setup_endpoints()
687+
688+
async def wait_until_server_ready(self, timeout: float | None = None) -> None:
689+
if timeout is None:
690+
timeout = self.timeout
691+
start = time.monotonic()
692+
while time.monotonic() - start < timeout:
693+
try:
694+
resp = await self.client.get("/readyz")
695+
if resp.status_code == 200:
696+
return
697+
except (httpx.TimeoutException, httpx.ConnectError):
698+
pass
699+
raise ServiceUnavailable(f"Server is not ready after {timeout} seconds")
700+
632701
async def is_ready(self, timeout: int | None = None) -> bool:
633702
try:
634703
resp = await self.client.get(
635704
"/readyz", timeout=timeout or httpx.USE_CLIENT_DEFAULT
636705
)
637706
return resp.status_code == 200
638707
except httpx.TimeoutException:
639-
logger.warn("Timed out waiting for runner to be ready")
708+
logger.warning("Timed out waiting for runner to be ready")
640709
return False
641710

642711
async def _get_stream(
@@ -647,7 +716,18 @@ async def _get_stream(
647716
async for data in resp:
648717
yield data
649718

650-
async def __aenter__(self: T) -> T:
719+
def __getattr__(self, name: str) -> t.Any:
720+
if not self._setup_done:
721+
raise RuntimeError(
722+
"Client is not set up yet, please use it as an async context manager"
723+
)
724+
else:
725+
raise AttributeError(
726+
f"'{self.__class__.__name__}' object has no attribute '{name}'"
727+
)
728+
729+
async def __aenter__(self) -> t.Self:
730+
await self._setup()
651731
return self
652732

653733
async def __aexit__(self, *args: t.Any) -> None:

src/_bentoml_impl/client/proxy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,16 @@ def __init__(
5959
server_ready_timeout=0,
6060
app=app,
6161
)
62+
# Setup async client with the same endpoints
63+
self._async.endpoints = self._sync.endpoints
64+
self._async._setup_endpoints()
6265
if service is not None:
6366
self._inner = service.inner
6467
self.endpoints = self._async.endpoints
6568
else:
6669
self.endpoints = {}
6770
self._inner = None
68-
super().__init__()
71+
self._setup_endpoints()
6972

7073
@property
7174
def to_async(self) -> AsyncHTTPClient:

0 commit comments

Comments
 (0)