Skip to content

Commit 7d8a08c

Browse files
Moving some functions to llm_utils.py
1 parent 97a39cc commit 7d8a08c

File tree

3 files changed

+146
-167
lines changed

3 files changed

+146
-167
lines changed

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 32 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -203,93 +203,52 @@ def initalize_messages(self, obs: Any) -> None:
203203
)
204204

205205

206-
def supports_tool_calling(model_name: str) -> bool:
207-
"""
208-
Check if the model supports tool calling.
209206

210-
Args:
211-
model_name (str): The name of the model.
212207

213-
Returns:
214-
bool: True if the model supports tool calling, False otherwise.
215-
"""
216-
import os
208+
# def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
209+
# default_model_args = {
210+
# "max_total_tokens": 200_000,
211+
# "max_input_tokens": 180_000,
212+
# "max_new_tokens": 2_000,
213+
# "temperature": 0.1,
214+
# "vision_support": True,
215+
# }
216+
# merged_args = {**default_model_args, **open_router_args}
217217

218-
import openai
218+
# return OpenRouterModelArgs(model_name=model_name, **merged_args)
219219

220-
client = openai.Client(
221-
api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"
222-
)
223-
try:
224-
response = client.chat.completions.create(
225-
model=model_name,
226-
messages=[{"role": "user", "content": "Call the test tool"}],
227-
tools=[
228-
{
229-
"type": "function",
230-
"function": {
231-
"name": "dummy_tool",
232-
"description": "Just a test tool",
233-
"parameters": {
234-
"type": "object",
235-
"properties": {},
236-
},
237-
},
238-
}
239-
],
240-
tool_choice="required",
241-
)
242-
response = response.to_dict()
243-
return "tool_calls" in response["choices"][0]["message"]
244-
except Exception as e:
245-
print(f"Model '{model_name}' error: {e}")
246-
return False
247-
248-
249-
def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
250-
default_model_args = {
251-
"max_total_tokens": 200_000,
252-
"max_input_tokens": 180_000,
253-
"max_new_tokens": 2_000,
254-
"temperature": 0.1,
255-
"vision_support": True,
256-
}
257-
merged_args = {**default_model_args, **open_router_args}
258-
259-
return OpenRouterModelArgs(model_name=model_name, **merged_args)
260-
261-
262-
def get_openrouter_tool_use_agent(
263-
model_name: str,
264-
model_args: dict = {},
265-
use_first_obs=True,
266-
tag_screenshot=True,
267-
use_raw_page_output=True,
268-
) -> ToolUseAgentArgs:
269-
# To Do : Check if OpenRouter endpoint specific args are working
270-
if not supports_tool_calling(model_name):
271-
raise ValueError(f"Model {model_name} does not support tool calling.")
272220

273-
model_args = get_openrouter_model(model_name, **model_args)
221+
# def get_openrouter_tool_use_agent(
222+
# model_name: str,
223+
# model_args: dict = {},
224+
# use_first_obs=True,
225+
# tag_screenshot=True,
226+
# use_raw_page_output=True,
227+
# ) -> ToolUseAgentArgs:
228+
# # To Do : Check if OpenRouter endpoint specific args are working
229+
# if not supports_tool_calling(model_name):
230+
# raise ValueError(f"Model {model_name} does not support tool calling.")
274231

275-
return ToolUseAgentArgs(
276-
model_args=model_args,
277-
use_first_obs=use_first_obs,
278-
tag_screenshot=tag_screenshot,
279-
use_raw_page_output=use_raw_page_output,
280-
)
232+
# model_args = get_openrouter_model(model_name, **model_args)
233+
234+
# return ToolUseAgentArgs(
235+
# model_args=model_args,
236+
# use_first_obs=use_first_obs,
237+
# tag_screenshot=tag_screenshot,
238+
# use_raw_page_output=use_raw_page_output,
239+
# )
281240

282241

283-
OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")
242+
# OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")
284243

285244

286245
AGENT_CONFIG = ToolUseAgentArgs(
287246
model_args=CLAUDE_MODEL_CONFIG,
288247
)
289248

290-
MT_TOOL_USE_AGENT = ToolUseAgentArgs(
291-
model_args=OPENROUTER_MODEL,
292-
)
249+
# MT_TOOL_USE_AGENT = ToolUseAgentArgs(
250+
# model_args=OPENROUTER_MODEL,
251+
# )
293252
CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
294253
model_args=OpenAIChatModelArgs(
295254
model_name="gpt-4o-2024-11-20",

src/agentlab/llm/llm_utils.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from warnings import warn
1313

1414
import numpy as np
15+
import openai
1516
import tiktoken
1617
import yaml
1718
from langchain.schema import BaseMessage
@@ -90,6 +91,102 @@ def retry(
9091
raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
9192

9293

94+
def call_with_retries(client_function, api_params, max_retries=5):
    """Call an LLM API with retries for transient failures.

    Retries on rate limiting (HTTP 429), server errors (5xx), responses that
    carry an explicit ``error`` field, and empty/malformed responses.
    Non-retriable API errors and unexpected exceptions are re-raised.

    Args:
        client_function (Callable): Function that performs the API call
            (e.g. ``client.chat.completions.create``).
        api_params (dict): Keyword arguments forwarded to ``client_function``.
        max_retries (int): Maximum number of attempts before giving up.

    Returns:
        The first valid API response object (one with a non-empty ``choices``).

    Raises:
        RuntimeError: If no valid response was obtained after ``max_retries``.
        openai.APIError: For non-retriable API errors.
    """
    import time  # local import: only needed for retry backoff

    for attempt in range(1, max_retries + 1):
        try:
            response = client_function(**api_params)

            # Some gateways (e.g. OpenRouter) return HTTP 200 with an error payload.
            if getattr(response, "error", None):
                logging.warning(
                    f"[Attempt {attempt}] API returned error: {response.error}. Retrying..."
                )
            elif hasattr(response, "choices") and response.choices:
                logging.info(f"[Attempt {attempt}] API call succeeded.")
                return response
            else:
                logging.warning(
                    f"[Attempt {attempt}] API returned empty or malformed response. Retrying..."
                )

        except openai.APIError as e:
            logging.error(f"[Attempt {attempt}] APIError: {e}")
            # openai>=1.0 exposes the HTTP status as `status_code` (on
            # APIStatusError subclasses); legacy clients used `http_status`.
            # The original `e.http_status` raised AttributeError on openai>=1.0.
            status = getattr(e, "status_code", getattr(e, "http_status", None))
            if status == 429:
                logging.warning("Rate limit exceeded. Retrying...")
            elif status is not None and status >= 500:
                logging.warning("Server error encountered. Retrying...")
            else:
                logging.error("Non-retriable API error occurred.")
                raise

        except Exception as e:
            logging.exception(f"[Attempt {attempt}] Unexpected exception occurred: {e}")
            raise

        # Exponential backoff between attempts to avoid hammering the API.
        if attempt < max_retries:
            time.sleep(min(2**attempt, 30))

    logging.error("Exceeded maximum retry attempts. API call failed.")
    raise RuntimeError("API call failed after maximum retries.")
143+
144+
145+
def supports_tool_calling_for_openrouter(
    model_name: str,
) -> bool:
    """Check if the OpenRouter-hosted model supports tool calling.

    Sends a minimal chat request with a dummy tool and ``tool_choice="required"``,
    then inspects whether the response message contains ``tool_calls``.
    Requires the ``OPENROUTER_API_KEY`` environment variable to be set.

    Args:
        model_name (str): The name of the model on OpenRouter.

    Returns:
        bool: True if the model supports tool calling, False otherwise
        (including on any API error, which is logged as a warning).
    """
    import os

    client = openai.Client(
        api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"
    )
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": "Call the test tool"}],
            tools=[
                {
                    "type": "function",
                    "function": {
                        "name": "dummy_tool",
                        "description": "Just a test tool",
                        "parameters": {
                            "type": "object",
                            "properties": {},
                        },
                    },
                }
            ],
            tool_choice="required",
        )
        response = response.to_dict()
        return "tool_calls" in response["choices"][0]["message"]
    except Exception as e:
        # Best-effort probe: report via logging (consistent with the rest of
        # this module) rather than print, and treat any failure as "unsupported".
        logging.warning(f"Model '{model_name}' error: {e}")
        return False
188+
189+
93190
def retry_multiple(
94191
chat: "ChatModel",
95192
messages: "Discussion",

0 commit comments

Comments
 (0)