Skip to content

Commit 6a1555e

Browse files
authored
Catch EnvironmentClient errors (#288)
1 parent a8ef49f commit 6a1555e

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

src/aviary/env_client.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
from abc import ABC, abstractmethod
34
from collections.abc import Mapping
@@ -9,7 +10,13 @@
910

1011
from aviary.env import Environment, TaskDataset
1112
from aviary.message import Message
12-
from aviary.tools import MessagesAdapter, Tool, ToolRequestMessage, ToolsAdapter
13+
from aviary.tools import (
14+
MessagesAdapter,
15+
Tool,
16+
ToolRequestMessage,
17+
ToolResponseMessage,
18+
ToolsAdapter,
19+
)
1320

1421
logger = logging.getLogger(__name__)
1522

@@ -29,13 +36,27 @@ def __init__(
2936
request_headers: httpx._types.HeaderTypes | None = None,
3037
request_timeout: float | None = None,
3138
api_key: str | None = None,
39+
catch_http_errors: bool = False,
3240
):
41+
"""Environment client.
42+
43+
Args:
44+
reset_endpoint_url: The URL of the reset endpoint.
45+
step_endpoint_url: The URL of the step endpoint.
46+
request_params: The query parameters to send with the request.
47+
request_headers: The headers to send with the request.
48+
request_timeout: The timeout for the request. Defaults to None, which means no timeout.
49+
api_key: The API key to send with the request. Defaults to None, which means no API key.
50+
catch_http_errors: Whether to catch HTTP errors (either status or bad JSON) and return
51+
empty/placeholder messages instead of raising an error.
52+
"""
3353
self._reset_request_url = reset_endpoint_url
3454
self._step_request_url = step_endpoint_url
3555
self._request_params = request_params
3656
self._request_headers = request_headers
3757
self._request_timeout = request_timeout
3858
self._api_key = api_key
59+
self._catch_http_errors = catch_http_errors
3960

4061
async def _post(self, url: str, json: Mapping[str, Any]) -> httpx.Response:
4162
async with httpx_aiohttp.HttpxAiohttpClient() as client:
@@ -56,7 +77,13 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
5677
response = await self._post(
5778
self._reset_request_url, json=self._make_post_json(self.state)
5879
)
59-
msgs, tools = response.json()
80+
try:
81+
response.raise_for_status()
82+
msgs, tools = response.json()
83+
except (httpx.HTTPStatusError, json.JSONDecodeError):
84+
if self._catch_http_errors:
85+
return [], []
86+
raise
6087
return (
6188
MessagesAdapter.validate_python(msgs),
6289
ToolsAdapter.validate_python(tools),
@@ -70,7 +97,17 @@ async def step(
7097
json=self._make_post_json(self.state)
7198
| {"action": action.model_dump(mode="json")},
7299
)
73-
messages, reward, done, truncated = response.json()
100+
try:
101+
response.raise_for_status()
102+
messages, reward, done, truncated = response.json()
103+
except (httpx.HTTPStatusError, json.JSONDecodeError) as e:
104+
if self._catch_http_errors:
105+
messages = [
106+
ToolResponseMessage.from_call(tool_call, content=str(e))
107+
for tool_call in action.tool_calls
108+
]
109+
return messages, 0.0, True, False
110+
raise
74111
return MessagesAdapter.validate_python(messages), reward, done, truncated
75112

76113
@abstractmethod

0 commit comments

Comments
 (0)