Skip to content

Commit 90150ab

Browse files
committed
multitool environment draft
1 parent 11071e9 commit 90150ab

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

src/agentlab/benchmarks/abstract_env.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import gym
21
from abc import ABC, abstractmethod
32

3+
import gym
4+
45

56
class AbstractEnvArgs(ABC):
67
"""Easily serialiazable class to store the arguments of an environment"""
@@ -21,7 +22,6 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":
2122

2223

2324
class AbstractEnv(gym.Env, ABC):
24-
2525
@abstractmethod
2626
def reset(self, seed: int = None) -> tuple[dict[str, any], dict[str, any]]:
2727
"""Reset the environment to the initial state, ready for an agent to start a new episode.
@@ -57,3 +57,7 @@ def step(self, action: str):
5757
@abstractmethod
def close(self):
    """Release every resource held by the environment.

    Concrete environments must override this method.
    """
60+
61+
@abstractmethod
def calculate_reward(self) -> float:
    """Return the reward earned by the most recent step.

    Concrete environments must override this method.
    """
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Any, Literal, Union
2+
3+
from pydantic import Annotated, Field, TypeAdapter
4+
from tapeagents.core import Action, Observation, Tape
5+
from tapeagents.environment import ToolCollectionEnvironment
6+
from tapeagents.tools.base import Multitool, Tool
7+
8+
from agentlab.benchmarks.abstract_env import AbstractEnv
9+
10+
# Tape specialization used by MultiToolGym to record an episode; its steps are
# Action or Observation instances.
# NOTE(review): the first type parameter (None) is presumably the tape's
# context type -- confirm against the tapeagents Tape definition.
EnvTape = Tape[None, Action | Observation]
11+
12+
13+
class FunctionCall(Action):
    """Action step requesting that a named function be called with arguments."""

    # Tag used as the discriminator when action JSON is parsed; the default
    # must be the literal string itself (a list default would never match).
    kind: Literal["function_call_action"] = "function_call_action"
    # Name of the function the tool is expected to run.
    function_name: str
    # Positional arguments for the call; None (the default) means "no args".
    args: list[Any] | None = None
    # Keyword arguments for the call; None (the default) means "no kwargs".
    kwargs: dict[str, Any] | None = None
18+
19+
20+
class FunctionCallResult(Observation):
    """Observation step carrying the value returned by a function call."""

    # Tag identifying this step type; the default must be the literal string
    # itself (a list default would not satisfy the Literal annotation).
    kind: Literal["function_call_result"] = "function_call_result"
    # Whatever the called function returned, unmodified.
    result: Any
23+
24+
25+
class SimpleFunctionCallTool(Tool):
    """Tool that exposes one Python callable as a ``FunctionCall`` action.

    The tool answers only actions whose ``function_name`` matches its own
    ``function_name``; when that field is left empty it is filled in from the
    callable's ``__name__``.
    """

    action = FunctionCall
    observation = FunctionCallResult
    # NOTE(review): this is annotated with the builtin ``callable`` rather than
    # ``typing.Callable``; confirm the model machinery accepts it as a field type.
    function: callable
    function_name: str = ""

    def model_post_init(self, __context):
        """Resolve ``function_name`` once the model fields are populated.

        Raises:
            ValueError: if no name was given and the callable has no
                ``__name__`` to fall back on.
        """
        inferred_name = getattr(self.function, "__name__", "")
        if self.function_name:
            # An explicit name always wins over the inferred one.
            return
        if not inferred_name:
            raise ValueError("Function has no name, function_name must be provided")
        # Store the inferred name; execute_action compares against this field,
        # so leaving it empty would make every call be rejected.
        self.function_name = inferred_name

    def execute_action(self, action: FunctionCall) -> FunctionCallResult:
        """Run the wrapped callable with the arguments carried by ``action``.

        Args:
            action: the call request; its ``function_name`` must match this tool.

        Returns:
            A ``FunctionCallResult`` wrapping the callable's return value.

        Raises:
            ValueError: if the action targets a different function name.
        """
        if self.function_name != action.function_name:
            raise ValueError(
                f"Unexpected function action {action.function_name}, expected {self.function_name}"
            )
        # args / kwargs may be None; treat that as "nothing to pass".
        positional = action.args or []
        keyword = action.kwargs or {}
        result = self.function(*positional, **keyword)
        return FunctionCallResult(result=result)
43+
44+
45+
class MultiToolGym(AbstractEnv):
    """Gym-style environment backed by a ``ToolCollectionEnvironment``.

    Actions arrive as JSON strings, are parsed into one of the action types
    reported by the tool collection, appended to an internal tape, and handed
    to the tool environment; the last step of the resulting tape is returned
    as the observation.
    """

    def __init__(self, tools: list[Tool | Multitool]):
        """Build the tool environment and the JSON parser for its actions.

        Args:
            tools: tools (or multitools) made available to the agent.

        Raises:
            ValueError: if the tool collection reports no action types.
        """
        # NOTE(review): ``Annotated`` is imported from ``pydantic`` at the top
        # of this file; it normally lives in ``typing`` -- confirm the import
        # resolves.
        self._env = ToolCollectionEnvironment(tools)
        self._actions = self._env.actions()
        # Normalise to a tuple so Union[...] works whatever sequence type
        # actions() returns.
        action_types = tuple(self._actions)
        if not action_types:
            raise ValueError("The provided tools expose no actions")
        if len(action_types) == 1:
            # Union[X] collapses to X, and a discriminator on a non-union type
            # is expected to be rejected, so parse the single type directly.
            schema = action_types[0]
        else:
            schema = Annotated[Union[action_types], Field(discriminator="kind")]
        self._actions_parser: TypeAdapter = TypeAdapter(schema)
        self.reset()

    def reset(self, seed=None):
        """Start a new episode with an empty tape.

        Args:
            seed: accepted for interface compatibility; not used here.

        Returns:
            An ``(observation, info)`` pair; both are empty dicts because
            nothing has been observed yet.
        """
        self._tape: EnvTape = EnvTape(steps=[])
        return {}, {}

    def step(self, action: str):
        """Parse ``action``, let the tool environment react, and report back.

        Args:
            action: JSON string describing one of the supported actions.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is the ``llm_dict()`` of the last tape step and
            ``info`` carries that step's metadata under ``"step_metadata"``.

        Raises:
            ValueError: if ``action`` cannot be parsed, or parses into
                something that is not an ``Action``.
        """
        try:
            action_step = self._actions_parser.validate_json(action)
        except Exception as err:
            # Keep the parser's error as the cause so the caller can see why.
            raise ValueError("Action must be a valid JSON dict") from err
        if not isinstance(action_step, Action):
            # Explicit check instead of assert: asserts are stripped under -O.
            step_kind = getattr(action_step, "kind", type(action_step).__name__)
            raise ValueError(f"{step_kind} is not an Action")
        self._tape += [action_step]
        self._tape = self._env.react(self._tape)
        observation_step: Observation = self._tape.steps[-1]
        reward = self.calculate_reward()
        # This draft never ends an episode on its own.
        terminated = False
        truncated = False
        env_info = {"step_metadata": observation_step.metadata}
        return observation_step.llm_dict(), reward, terminated, truncated, env_info

    def calculate_reward(self) -> float:
        """Return the reward for the last step; always 0.0 in this draft."""
        return 0.0

    def close(self):
        """Close the underlying tool environment."""
        self._env.close()

0 commit comments

Comments
 (0)