1+ import json
12import logging
23from abc import ABC , abstractmethod
34from collections .abc import Mapping
910
1011from aviary .env import Environment , TaskDataset
1112from aviary .message import Message
12- from aviary .tools import MessagesAdapter , Tool , ToolRequestMessage , ToolsAdapter
13+ from aviary .tools import (
14+ MessagesAdapter ,
15+ Tool ,
16+ ToolRequestMessage ,
17+ ToolResponseMessage ,
18+ ToolsAdapter ,
19+ )
1320
1421logger = logging .getLogger (__name__ )
1522
@@ -29,13 +36,27 @@ def __init__(
2936 request_headers : httpx ._types .HeaderTypes | None = None ,
3037 request_timeout : float | None = None ,
3138 api_key : str | None = None ,
39+ catch_http_errors : bool = False ,
3240 ):
41+ """Environment client.
42+
43+ Args:
44+ reset_endpoint_url: The URL of the reset endpoint.
45+ step_endpoint_url: The URL of the step endpoint.
46+ request_params: The query parameters to send with the request.
47+ request_headers: The headers to send with the request.
48+ request_timeout: The timeout for the request. Defaults to None, which means no timeout.
49+ api_key: The API key to send with the request. Defaults to None, which means no API key.
50+ catch_http_errors: Whether to catch HTTP errors (either status or bad JSON) and return
51+ empty/placeholder messages instead of raising an error.
52+ """
3353 self ._reset_request_url = reset_endpoint_url
3454 self ._step_request_url = step_endpoint_url
3555 self ._request_params = request_params
3656 self ._request_headers = request_headers
3757 self ._request_timeout = request_timeout
3858 self ._api_key = api_key
59+ self ._catch_http_errors = catch_http_errors
3960
4061 async def _post (self , url : str , json : Mapping [str , Any ]) -> httpx .Response :
4162 async with httpx_aiohttp .HttpxAiohttpClient () as client :
@@ -56,7 +77,13 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
5677 response = await self ._post (
5778 self ._reset_request_url , json = self ._make_post_json (self .state )
5879 )
59- msgs , tools = response .json ()
80+ try :
81+ response .raise_for_status ()
82+ msgs , tools = response .json ()
83+ except (httpx .HTTPStatusError , json .JSONDecodeError ):
84+ if self ._catch_http_errors :
85+ return [], []
86+ raise
6087 return (
6188 MessagesAdapter .validate_python (msgs ),
6289 ToolsAdapter .validate_python (tools ),
@@ -70,7 +97,17 @@ async def step(
7097 json = self ._make_post_json (self .state )
7198 | {"action" : action .model_dump (mode = "json" )},
7299 )
73- messages , reward , done , truncated = response .json ()
100+ try :
101+ response .raise_for_status ()
102+ messages , reward , done , truncated = response .json ()
103+ except (httpx .HTTPStatusError , json .JSONDecodeError ) as e :
104+ if self ._catch_http_errors :
105+ messages = [
106+ ToolResponseMessage .from_call (tool_call , content = str (e ))
107+ for tool_call in action .tool_calls
108+ ]
109+ return messages , 0.0 , True , False
110+ raise
74111 return MessagesAdapter .validate_python (messages ), reward , done , truncated
75112
76113 @abstractmethod
0 commit comments