Skip to content

Commit 20502a8

Browse files
committed
move action and tool classes to actions module
1 parent e2cd4b9 commit 20502a8

File tree

9 files changed

+84
-82
lines changed

9 files changed

+84
-82
lines changed

src/agentlab/actions.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,81 @@
11
import json
22
import logging
3+
from typing import Any, Callable, Literal
34

45
from bgym import AbstractActionSet
6+
from langchain_core.utils.function_calling import convert_to_openai_tool
7+
from pydantic import BaseModel
58

6-
from agentlab.backends.browser.base import FunctionCall, ToolCallAction, ToolSpec
79
from agentlab.llm.llm_utils import parse_html_tags_raise
810

911
logger = logging.getLogger(__name__)
1012

1113

14+
class FunctionSpec(BaseModel):
15+
"""
16+
A class representing the specification of a function.
17+
18+
Attributes:
19+
name (str): The name of the function.
20+
description (str): A brief description of the function.
21+
parameters (dict): A dictionary containing the parameters of the function.
22+
"""
23+
24+
name: str
25+
description: str
26+
parameters: dict
27+
28+
29+
class FunctionCall(BaseModel):
30+
"""
31+
A class representing a function call.
32+
33+
Attributes:
34+
name (str): The name of the function being called.
35+
arguments (Any): The arguments to be passed to the function.
36+
"""
37+
38+
name: str
39+
arguments: Any
40+
41+
42+
class ToolCallAction(BaseModel):
43+
id: str = ""
44+
function: FunctionCall
45+
46+
def llm_view(self, **kwargs) -> str:
47+
return self.model_dump_json(indent=2)
48+
49+
50+
class ToolSpec(BaseModel):
51+
"""
52+
ToolSpec is a model that represents a tool specification with a type and a function.
53+
54+
Attributes:
55+
type (Literal["function"]): The type of the tool, which is always "function".
56+
function (FunctionSpec): The specification of the function.
57+
"""
58+
59+
type: Literal["function"] = "function"
60+
function: FunctionSpec
61+
62+
def description(self) -> str:
63+
return f"{self.function.name} - {self.function.description}"
64+
65+
@classmethod
66+
def from_function(cls, function: Callable):
67+
"""
68+
Creates an instance of the class by validating the model from a given function.
69+
70+
Args:
71+
function (Callable): The function to be converted and validated.
72+
73+
Returns:
74+
(ToolSpec): An instance of the class with the validated model.
75+
"""
76+
return cls.model_validate(convert_to_openai_tool(function))
77+
78+
1279
class ToolsActionSet(AbstractActionSet):
1380
multiaction: bool = False
1481
strict: bool = False
@@ -49,19 +116,15 @@ def parse_action(cls, llm_output: str) -> ToolCallAction:
49116
if "<action>" in llm_output:
50117
content_dict, valid, retry_message = parse_html_tags_raise(llm_output, keys=["action"])
51118
if not valid or "action" not in content_dict:
52-
raise ValueError(
53-
f"Invalid action: llm_output: {llm_output}, retry_message: {retry_message}"
54-
)
119+
raise ValueError(f"Invalid action: llm_output: {llm_output}, retry_message: {retry_message}")
55120
action_str = content_dict["action"]
56121
else:
57122
action_str = llm_output
58123
try:
59124
action_dict = json.loads(action_str)
60125
except json.JSONDecodeError:
61126
raise ValueError(f"Failed to parse action: {action_str}")
62-
return ToolCallAction(
63-
function=FunctionCall(name=action_dict["name"], arguments=action_dict["arguments"])
64-
)
127+
return ToolCallAction(function=FunctionCall(name=action_dict["name"], arguments=action_dict["arguments"]))
65128

66129
def to_python_code(self, action) -> str:
67130
return action

src/agentlab/agents/tapeagent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from tapeagents.tool_calling import ToolSpec
2828
from termcolor import colored
2929

30+
from agentlab.actions import ToolSpec as AgentlabToolSpec
3031
from agentlab.agents.agent_args import AgentArgs
31-
from agentlab.backends.browser.base import ToolSpec as AgentlabToolSpec
3232

3333
logger = logging.getLogger(__name__)
3434
logger.setLevel(logging.INFO)

src/agentlab/backends/browser/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from agentlab.backends.browser.base import BrowserBackend, FunctionCall, ToolCallAction, ToolSpec
1+
from agentlab.actions import FunctionCall, ToolCallAction, ToolSpec
2+
from agentlab.backends.browser.base import BrowserBackend
23
from agentlab.backends.browser.env import BrowserEnv, BrowserEnvArgs
34
from agentlab.backends.browser.mcp import MCPBrowserBackend, MCPClient
45
from agentlab.backends.browser.mcp_playwright import MCPPlaywright

src/agentlab/backends/browser/base.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,12 @@
11
import logging
22
from abc import ABC, abstractmethod
3-
from typing import Any, Callable, Literal
43

5-
from langchain_core.utils.function_calling import convert_to_openai_tool
64
from PIL import Image
75
from pydantic import BaseModel
86

9-
logger = logging.getLogger(__name__)
10-
11-
12-
class FunctionCall(BaseModel):
13-
"""
14-
A class representing a function call.
15-
16-
Attributes:
17-
name (str): The name of the function being called.
18-
arguments (Any): The arguments to be passed to the function.
19-
"""
20-
21-
name: str
22-
arguments: Any
23-
24-
25-
class FunctionSpec(BaseModel):
26-
"""
27-
A class representing the specification of a function.
28-
29-
Attributes:
30-
name (str): The name of the function.
31-
description (str): A brief description of the function.
32-
parameters (dict): A dictionary containing the parameters of the function.
33-
"""
34-
35-
name: str
36-
description: str
37-
parameters: dict
38-
7+
from agentlab.actions import ToolCallAction, ToolSpec
398

40-
class ToolCallAction(BaseModel):
41-
id: str = ""
42-
function: FunctionCall
43-
44-
def llm_view(self, **kwargs) -> str:
45-
return self.model_dump_json(indent=2)
46-
47-
48-
class ToolSpec(BaseModel):
49-
"""
50-
ToolSpec is a model that represents a tool specification with a type and a function.
51-
52-
Attributes:
53-
type (Literal["function"]): The type of the tool, which is always "function".
54-
function (FunctionSpec): The specification of the function.
55-
"""
56-
57-
type: Literal["function"] = "function"
58-
function: FunctionSpec
59-
60-
def description(self) -> str:
61-
return f"{self.function.name} - {self.function.description}"
62-
63-
@classmethod
64-
def from_function(cls, function: Callable):
65-
"""
66-
Creates an instance of the class by validating the model from a given function.
67-
68-
Args:
69-
function (Callable): The function to be converted and validated.
70-
71-
Returns:
72-
(ToolSpec): An instance of the class with the validated model.
73-
"""
74-
return cls.model_validate(convert_to_openai_tool(function))
9+
logger = logging.getLogger(__name__)
7510

7611

7712
class BrowserBackend(BaseModel, ABC):

src/agentlab/backends/browser/env.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6-
from agentlab.actions import ToolsActionSet
7-
from agentlab.backends.browser.base import BrowserBackend, ToolCallAction, ToolSpec
6+
from agentlab.actions import ToolCallAction, ToolsActionSet, ToolSpec
7+
from agentlab.backends.browser.base import BrowserBackend
88
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
99
from agentlab.benchmarks.web_task import AbstractWebTask
1010

src/agentlab/backends/browser/mcp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from mcp import Tool as MCPTool
1111
from mcp.types import CallToolResult, ImageContent, TextContent
1212

13-
from agentlab.backends.browser.base import BrowserBackend, FunctionSpec, ToolCallAction, ToolSpec
13+
from agentlab.actions import FunctionSpec, ToolCallAction, ToolSpec
14+
from agentlab.backends.browser.base import BrowserBackend
1415

1516
logger = logging.getLogger(__name__)
1617

src/agentlab/backends/browser/mcp_playwright.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from PIL import Image
66

7-
from agentlab.backends.browser.mcp import MCPBrowserBackend, ToolCallAction
7+
from agentlab.actions import ToolCallAction
8+
from agentlab.backends.browser.mcp import MCPBrowserBackend
89

910
logger = logging.getLogger(__name__)
1011

src/agentlab/backends/browser/playwright.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from PIL import Image
77
from playwright.async_api import Browser, Page, async_playwright
88

9-
from agentlab.backends.browser.base import BrowserBackend, ToolCallAction, ToolSpec
9+
from agentlab.actions import ToolCallAction, ToolSpec
10+
from agentlab.backends.browser.base import BrowserBackend
1011

1112
logger = logging.getLogger(__name__)
1213

src/agentlab/benchmarks/web_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pydantic import BaseModel
44

5-
from agentlab.backends.browser.base import ToolSpec
5+
from agentlab.actions import ToolSpec
66

77

88
class AbstractWebTask(BaseModel):

0 commit comments

Comments
 (0)