Skip to content

Commit 97a39cc

Browse files
added vllm-support-for-tool-use-agent
1 parent fe05d75 commit 97a39cc

File tree

2 files changed

+192
-46
lines changed

2 files changed

+192
-46
lines changed

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
from agentlab.agents.agent_args import AgentArgs
1212
from agentlab.llm.llm_utils import image_to_png_base64_url
1313
from agentlab.llm.response_api import (
14+
BaseModelArgs,
1415
ClaudeResponseModelArgs,
1516
MessageBuilder,
1617
OpenAIChatModelArgs,
1718
OpenAIResponseModelArgs,
1819
OpenRouterModelArgs,
1920
ResponseLLMOutput,
21+
VLLMModelArgs,
2022
)
2123
from agentlab.llm.tracking import cost_tracker_decorator
2224
from browsergym.core.observation import extract_screenshot
@@ -264,7 +266,7 @@ def get_openrouter_tool_use_agent(
264266
tag_screenshot=True,
265267
use_raw_page_output=True,
266268
) -> ToolUseAgentArgs:
267-
#To Do : Check if OpenRouter endpoint specific args are working
269+
# To Do : Check if OpenRouter endpoint specific args are working
268270
if not supports_tool_calling(model_name):
269271
raise ValueError(f"Model {model_name} does not support tool calling.")
270272

@@ -301,13 +303,43 @@ def get_openrouter_tool_use_agent(
301303

302304

303305
OAI_CHAT_TOOl_AGENT = ToolUseAgentArgs(
304-
model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06"),
305-
use_first_obs=False,
306-
tag_screenshot=False,
307-
use_raw_page_output=True,
306+
model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
308307
)
309308

310309

310+
PROVIDER_FACTORY_MAP = {
311+
"openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
312+
"openrouter": OpenRouterModelArgs,
313+
"vllm": VLLMModelArgs,
314+
"antrophic": ClaudeResponseModelArgs,
315+
}
316+
317+
318+
def get_tool_use_agent(
319+
api_provider: str,
320+
model_args: "BaseModelArgs",
321+
tool_use_agent_args: dict = None,
322+
api_provider_spec=None,
323+
) -> ToolUseAgentArgs:
324+
325+
if api_provider == "openai":
326+
assert (
327+
api_provider_spec is not None
328+
), "Endpoint specification is required for OpenAI provider. Choose between 'chatcompletion' and 'response'."
329+
330+
model_args_factory = (
331+
PROVIDER_FACTORY_MAP[api_provider]
332+
if api_provider_spec is None
333+
else PROVIDER_FACTORY_MAP[api_provider][api_provider_spec]
334+
)
335+
336+
# Create the agent with model arguments from the factory
337+
agent = ToolUseAgentArgs(
338+
model_args=model_args_factory(**model_args), **(tool_use_agent_args or {})
339+
)
340+
return agent
341+
342+
311343
## We have three providers that we want to support.
312344
# Anthropic
313345
# OpenAI

0 commit comments

Comments
 (0)