Skip to content

Commit f2c480a

Browse files
committed
fix tapeagent
1 parent 1f090c2 commit f2c480a

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

src/agentlab/agents/tapeagent/agent.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import tempfile
33
from dataclasses import dataclass
4-
from typing import Literal
4+
from typing import Any, Literal
55

66
import bgym
77
import hydra
@@ -47,11 +47,17 @@ class AgentResponse(Thought):
4747
kind: Literal["agent_response"] = "agent_response"
4848
response: str
4949

50+
def llm_view(self, **kwargs) -> str:
51+
return self.response
52+
5053

5154
class AgentThinking(Thought):
5255
kind: Literal["agent_thinking"] = "agent_thinking"
5356
thinking: str
5457

58+
def llm_view(self, **kwargs) -> str:
59+
return self.thinking
60+
5561

5662
class Tape(BaseTape):
5763
metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata) # type: ignore
@@ -202,7 +208,10 @@ def __init__(self, agent: Agent):
202208
self.agent = agent
203209
self.tape = Tape(steps=[])
204210

205-
def obs_preprocessor(self, obs: Observation | list[Observation] | dict) -> list[Observation]:
211+
def obs_preprocessor(self, obs: Any) -> list[Observation]:
212+
return obs
213+
214+
def obs_to_steps(self, obs: Observation | list[Observation] | dict) -> list[Observation]:
206215
if isinstance(obs, Observation):
207216
obs = [obs]
208217
if isinstance(obs, dict):
@@ -231,8 +240,10 @@ def obs_preprocessor(self, obs: Observation | list[Observation] | dict) -> list[
231240
logger.info(colored(f"Observations:\n{obs_view}", "green"))
232241
return obs
233242

234-
def get_action(self, obs: Observation | list[Observation]) -> tuple[Action, TapeAgentInfo]:
235-
self.tape += obs # type: ignore
243+
def get_action(
244+
self, obs: Observation | list[Observation] | dict
245+
) -> tuple[Action, TapeAgentInfo]:
246+
self.tape += self.obs_to_steps(obs)
236247
thoughts: list[Thought] = []
237248
action = None
238249
while not action:
@@ -250,7 +261,8 @@ def get_action(self, obs: Observation | list[Observation]) -> tuple[Action, Tape
250261
# there could be control flow steps for switching nodes and if clauses
251262
logger.info(f"Other step: {type(event.step)}")
252263
logger.info(f"Tape after run: ({len(self.tape)}) {[type(s).__name__ for s in self.tape]}")
253-
return (action, TapeAgentInfo(thoughts=thoughts))
264+
think_str = "\n".join([t.llm_view() for t in thoughts])
265+
return (action, {"thoughts": thoughts, "think": think_str})
254266

255267
@property
256268
def final_tape(self) -> Tape:

0 commit comments

Comments
 (0)