|
1 | 1 | import json |
2 | 2 | import logging |
| 3 | +from typing import Any, Callable, Literal |
3 | 4 |
|
4 | 5 | from bgym import AbstractActionSet |
| 6 | +from langchain_core.utils.function_calling import convert_to_openai_tool |
| 7 | +from pydantic import BaseModel |
5 | 8 |
|
6 | | -from agentlab.backends.browser.base import FunctionCall, ToolCallAction, ToolSpec |
7 | 9 | from agentlab.llm.llm_utils import parse_html_tags_raise |
8 | 10 |
|
9 | 11 | logger = logging.getLogger(__name__) |
10 | 12 |
|
11 | 13 |
|
| 14 | +class FunctionSpec(BaseModel): |
| 15 | + """ |
| 16 | + A class representing the specification of a function. |
| 17 | +
|
| 18 | + Attributes: |
| 19 | + name (str): The name of the function. |
| 20 | + description (str): A brief description of the function. |
| 21 | + parameters (dict): A dictionary containing the parameters of the function. |
| 22 | + """ |
| 23 | + |
| 24 | + name: str |
| 25 | + description: str |
| 26 | + parameters: dict |
| 27 | + |
| 28 | + |
| 29 | +class FunctionCall(BaseModel): |
| 30 | + """ |
| 31 | + A class representing a function call. |
| 32 | +
|
| 33 | + Attributes: |
| 34 | + name (str): The name of the function being called. |
| 35 | + arguments (Any): The arguments to be passed to the function. |
| 36 | + """ |
| 37 | + |
| 38 | + name: str |
| 39 | + arguments: Any |
| 40 | + |
| 41 | + |
| 42 | +class ToolCallAction(BaseModel): |
| 43 | + id: str = "" |
| 44 | + function: FunctionCall |
| 45 | + |
| 46 | + def llm_view(self, **kwargs) -> str: |
| 47 | + return self.model_dump_json(indent=2) |
| 48 | + |
| 49 | + |
| 50 | +class ToolSpec(BaseModel): |
| 51 | + """ |
| 52 | + ToolSpec is a model that represents a tool specification with a type and a function. |
| 53 | +
|
| 54 | + Attributes: |
| 55 | + type (Literal["function"]): The type of the tool, which is always "function". |
| 56 | + function (FunctionSpec): The specification of the function. |
| 57 | + """ |
| 58 | + |
| 59 | + type: Literal["function"] = "function" |
| 60 | + function: FunctionSpec |
| 61 | + |
| 62 | + def description(self) -> str: |
| 63 | + return f"{self.function.name} - {self.function.description}" |
| 64 | + |
| 65 | + @classmethod |
| 66 | + def from_function(cls, function: Callable): |
| 67 | + """ |
| 68 | + Creates an instance of the class by validating the model from a given function. |
| 69 | +
|
| 70 | + Args: |
| 71 | + function (Callable): The function to be converted and validated. |
| 72 | +
|
| 73 | + Returns: |
| 74 | + (ToolSpec): An instance of the class with the validated model. |
| 75 | + """ |
| 76 | + return cls.model_validate(convert_to_openai_tool(function)) |
| 77 | + |
| 78 | + |
12 | 79 | class ToolsActionSet(AbstractActionSet): |
13 | 80 | multiaction: bool = False |
14 | 81 | strict: bool = False |
@@ -49,19 +116,15 @@ def parse_action(cls, llm_output: str) -> ToolCallAction: |
49 | 116 | if "<action>" in llm_output: |
50 | 117 | content_dict, valid, retry_message = parse_html_tags_raise(llm_output, keys=["action"]) |
51 | 118 | if not valid or "action" not in content_dict: |
52 | | - raise ValueError( |
53 | | - f"Invalid action: llm_output: {llm_output}, retry_message: {retry_message}" |
54 | | - ) |
| 119 | + raise ValueError(f"Invalid action: llm_output: {llm_output}, retry_message: {retry_message}") |
55 | 120 | action_str = content_dict["action"] |
56 | 121 | else: |
57 | 122 | action_str = llm_output |
58 | 123 | try: |
59 | 124 | action_dict = json.loads(action_str) |
60 | 125 | except json.JSONDecodeError: |
61 | 126 | raise ValueError(f"Failed to parse action: {action_str}") |
62 | | - return ToolCallAction( |
63 | | - function=FunctionCall(name=action_dict["name"], arguments=action_dict["arguments"]) |
64 | | - ) |
| 127 | + return ToolCallAction(function=FunctionCall(name=action_dict["name"], arguments=action_dict["arguments"])) |
65 | 128 |
|
66 | 129 | def to_python_code(self, action) -> str: |
67 | 130 | return action |
0 commit comments