|
| 1 | +from typing import Any, Literal, Union |
| 2 | + |
| 3 | +from pydantic import Annotated, Field, TypeAdapter |
| 4 | +from tapeagents.core import Action, Observation, Tape |
| 5 | +from tapeagents.environment import ToolCollectionEnvironment |
| 6 | +from tapeagents.tools.base import Multitool, Tool |
| 7 | + |
| 8 | +from agentlab.benchmarks.abstract_env import AbstractEnv |
| 9 | + |
| 10 | +EnvTape = Tape[None, Action | Observation] |
| 11 | + |
| 12 | + |
| 13 | +class FunctionCall(Action): |
| 14 | + kind: Literal["function_call_action"] = ["function_call_action"] |
| 15 | + function_name: str |
| 16 | + args: list[Any] | None |
| 17 | + kwargs: dict[str, Any] | None |
| 18 | + |
| 19 | + |
| 20 | +class FunctionCallResult(Observation): |
| 21 | + kind: Literal["function_call_result"] = ["function_call_result"] |
| 22 | + result: Any |
| 23 | + |
| 24 | + |
| 25 | +class SimpleFunctionCallTool(Tool): |
| 26 | + action = FunctionCall |
| 27 | + observation = FunctionCallResult |
| 28 | + function: callable |
| 29 | + function_name: str = "" |
| 30 | + |
| 31 | + def model_post_init(self, __context): |
| 32 | + function_name = getattr(self.function, "__name__", "") |
| 33 | + if not function_name and not self.function_name: |
| 34 | + raise ValueError("Function has no name, function_name must be provided") |
| 35 | + |
| 36 | + def execute_action(self, action: FunctionCall) -> FunctionCallResult: |
| 37 | + if not self.function_name == action.function_name: |
| 38 | + raise ValueError( |
| 39 | + f"Unexpected function action {action.function_name}, expected {self.function_name}" |
| 40 | + ) |
| 41 | + result = self.function(*action.args, **action.kwargs) |
| 42 | + return FunctionCallResult(result=result) |
| 43 | + |
| 44 | + |
| 45 | +class MultiToolGym(AbstractEnv): |
| 46 | + def __init__(self, tools: list[Tool | Multitool]): |
| 47 | + self._env = ToolCollectionEnvironment(tools) |
| 48 | + self._actions = self._env.actions() |
| 49 | + self._actions_parser: TypeAdapter = TypeAdapter( |
| 50 | + Annotated[Union[self._actions], Field(discriminator="kind")] |
| 51 | + ) |
| 52 | + self.reset() |
| 53 | + |
| 54 | + def reset(self, seed=None): |
| 55 | + self._tape: EnvTape = EnvTape(steps=[]) |
| 56 | + |
| 57 | + def step(self, action: str): |
| 58 | + try: |
| 59 | + action_step = self._actions_parser.validate_json(action) |
| 60 | + except Exception: |
| 61 | + raise ValueError("Action must be a valid JSON dict") |
| 62 | + assert isinstance(action_step, Action), "{action_step.kind} is not an Action" |
| 63 | + self._tape += [action_step] |
| 64 | + self._tape = self._env.react(self._tape) |
| 65 | + observation_step: Observation = self._tape.steps[-1] |
| 66 | + reward = self.calculate_reward() |
| 67 | + terminated = False |
| 68 | + truncated = False |
| 69 | + env_info = {"step_metadata": observation_step.metadata} |
| 70 | + return observation_step.llm_dict(), reward, terminated, truncated, env_info |
| 71 | + |
| 72 | + def calculate_reward(self) -> float: |
| 73 | + return 0.0 |
| 74 | + |
| 75 | + def close(self): |
| 76 | + self._env.close() |
0 commit comments