|
11 | 11 | from agentlab.agents.agent_args import AgentArgs |
12 | 12 | from agentlab.llm.llm_utils import image_to_png_base64_url |
13 | 13 | from agentlab.llm.response_api import ( |
| 14 | + BaseModelArgs, |
14 | 15 | ClaudeResponseModelArgs, |
15 | 16 | MessageBuilder, |
16 | 17 | OpenAIChatModelArgs, |
17 | 18 | OpenAIResponseModelArgs, |
18 | 19 | OpenRouterModelArgs, |
19 | 20 | ResponseLLMOutput, |
| 21 | + VLLMModelArgs, |
20 | 22 | ) |
21 | 23 | from agentlab.llm.tracking import cost_tracker_decorator |
22 | 24 | from browsergym.core.observation import extract_screenshot |
@@ -264,7 +266,7 @@ def get_openrouter_tool_use_agent( |
264 | 266 | tag_screenshot=True, |
265 | 267 | use_raw_page_output=True, |
266 | 268 | ) -> ToolUseAgentArgs: |
267 | | - #To Do : Check if OpenRouter endpoint specific args are working |
| 269 | + # To Do : Check if OpenRouter endpoint specific args are working |
268 | 270 | if not supports_tool_calling(model_name): |
269 | 271 | raise ValueError(f"Model {model_name} does not support tool calling.") |
270 | 272 |
|
@@ -301,13 +303,43 @@ def get_openrouter_tool_use_agent( |
301 | 303 |
|
302 | 304 |
|
# Default GPT-4o tool-use agent wired to the OpenAI chat-completions API.
# NOTE(review): the explicit use_first_obs / tag_screenshot /
# use_raw_page_output overrides were dropped here, so this config now relies
# on ToolUseAgentArgs' own defaults for those fields — confirm they match
# the previously hard-coded values (False / False / True).
OAI_CHAT_TOOl_AGENT = ToolUseAgentArgs(
    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
)
309 | 308 |
|
310 | 309 |
|
# Maps an API-provider name to the BaseModelArgs factory used to construct
# its model arguments. The "openai" entry is nested: an endpoint spec
# ("chatcompletion" or "response") selects between the two OpenAI APIs.
PROVIDER_FACTORY_MAP = {
    "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
    "openrouter": OpenRouterModelArgs,
    "vllm": VLLMModelArgs,
    # Correctly spelled key; the misspelled "antrophic" is kept as an alias
    # so any existing callers using the old key keep working.
    "anthropic": ClaudeResponseModelArgs,
    "antrophic": ClaudeResponseModelArgs,
}
| 316 | + |
| 317 | + |
def get_tool_use_agent(
    api_provider: str,
    model_args: dict,
    tool_use_agent_args: dict = None,
    api_provider_spec: str = None,
) -> ToolUseAgentArgs:
    """Build a :class:`ToolUseAgentArgs` for the given API provider.

    Args:
        api_provider: Key into ``PROVIDER_FACTORY_MAP`` (e.g. "openai",
            "openrouter", "vllm", "anthropic").
        model_args: Keyword arguments forwarded to the provider's model-args
            factory (e.g. ``{"model_name": ...}``). This must be a mapping —
            it is unpacked with ``**`` below.
        tool_use_agent_args: Optional keyword arguments forwarded to
            :class:`ToolUseAgentArgs`.
        api_provider_spec: Required for "openai" to choose between the
            "chatcompletion" and "response" endpoints; ignored for providers
            with a single factory.

    Returns:
        A configured :class:`ToolUseAgentArgs` instance.

    Raises:
        ValueError: If ``api_provider`` is unknown, or if
            ``api_provider_spec`` is missing or invalid for a provider that
            requires one.
    """
    try:
        factory_entry = PROVIDER_FACTORY_MAP[api_provider]
    except KeyError:
        raise ValueError(
            f"Unknown api_provider {api_provider!r}. "
            f"Choose one of: {sorted(PROVIDER_FACTORY_MAP)}."
        ) from None

    if isinstance(factory_entry, dict):
        # Nested entry (currently only "openai"): an endpoint spec is
        # required to pick the concrete factory. Raise ValueError instead of
        # asserting, since asserts are stripped under `python -O`.
        if api_provider_spec is None:
            raise ValueError(
                f"Endpoint specification is required for {api_provider!r}. "
                f"Choose between: {sorted(factory_entry)}."
            )
        try:
            model_args_factory = factory_entry[api_provider_spec]
        except KeyError:
            raise ValueError(
                f"Unknown api_provider_spec {api_provider_spec!r} for "
                f"{api_provider!r}. Choose between: {sorted(factory_entry)}."
            ) from None
    else:
        # Single-factory provider: any spec is irrelevant and ignored
        # (the original code crashed with TypeError when a spec was passed
        # alongside e.g. "vllm", because it subscripted the factory class).
        model_args_factory = factory_entry

    return ToolUseAgentArgs(
        model_args=model_args_factory(**model_args), **(tool_use_agent_args or {})
    )
| 341 | + |
| 342 | + |
311 | 343 | ## We have three providers that we want to support. |
312 | 344 | # Anthropic |
313 | 345 | # OpenAI |
|
0 commit comments