Skip to content

Commit 8be1174

Browse files
committed
revert tapeagent changes
1 parent: 212c0f4 · commit: 8be1174

File tree

1 file changed

+11
-180
lines changed

1 file changed

+11
-180
lines changed
Lines changed: 11 additions & 180 deletions
Original file line number | Diff line number | Diff line change
@@ -1,33 +1,15 @@
11
import logging
2-
import tempfile
32
from dataclasses import dataclass
4-
from typing import Any, Literal
3+
from typing import Literal
54

65
import bgym
76
import hydra
8-
from litellm import ChatCompletionThinkingBlock
97
from omegaconf import DictConfig
10-
from PIL import Image
118
from pydantic import Field
129
from tapeagents.agent import Agent
13-
from tapeagents.core import (
14-
Action,
15-
ControlFlow,
16-
LLMOutputParsingFailureAction,
17-
Observation,
18-
SetNextNode,
19-
StopStep,
20-
TapeMetadata,
21-
Thought,
22-
)
10+
from tapeagents.core import Action, Observation, StopStep, TapeMetadata, Thought
2311
from tapeagents.core import Tape as BaseTape
24-
from tapeagents.llms import LLMStream
25-
from tapeagents.nodes import FatalError, StandardNode
26-
from tapeagents.steps import ImageObservation
27-
from tapeagents.tool_calling import ToolSpec
28-
from termcolor import colored
2912

30-
from agentlab.actions import ToolSpec as AgentlabToolSpec
3113
from agentlab.agents.agent_args import AgentArgs
3214

3315
logger = logging.getLogger(__name__)
@@ -44,65 +26,10 @@ class ExtendedMetadata(TapeMetadata):
4426
other: dict = {}
4527

4628

47-
class AgentResponse(Thought):
48-
kind: Literal["agent_response"] = "agent_response"
49-
response: str
50-
51-
def llm_view(self, **kwargs) -> str:
52-
return self.response
53-
54-
55-
class AgentThinking(Thought):
56-
kind: Literal["agent_thinking"] = "agent_thinking"
57-
thinking: str
58-
59-
def llm_view(self, **kwargs) -> str:
60-
return self.thinking
61-
62-
6329
class Tape(BaseTape):
6430
metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata) # type: ignore
6531

6632

67-
class ToolCallNode(StandardNode):
68-
use_known_actions: bool = True
69-
use_function_calls: bool = True
70-
71-
def generate_steps(self, agent: Agent, tape: Tape, llm_stream: LLMStream):
72-
new_steps = []
73-
for event in llm_stream:
74-
if event.output.get("reasoning_content"):
75-
logger.info(colored(f"LLM reasoning:\n{event.output.reasoning_content}", "yellow"))
76-
new_steps.append(AgentThinking(thinking=event.output.reasoning_content))
77-
if event.output.get("thinking_blocks"):
78-
for block in event.output.thinking_blocks:
79-
if isinstance(block, ChatCompletionThinkingBlock):
80-
logger.info(colored(f"LLM thinking block:\n{block}", "yellow"))
81-
new_steps.append(AgentThinking(thinking=block.content))
82-
if event.output.content:
83-
logger.info(colored(f"LLM output:\n{event.output.content}", "cyan"))
84-
new_steps.append(AgentResponse(response=event.output.content))
85-
if event.output.tool_calls:
86-
logger.info(colored(f"LLM tool calls:\n{event.output.tool_calls}", "magenta"))
87-
new_steps += [
88-
self.tool_call_to_step(agent, tool_call)
89-
for tool_call in event.output.tool_calls
90-
]
91-
for step in new_steps:
92-
yield step
93-
if isinstance(step, LLMOutputParsingFailureAction):
94-
yield SetNextNode(next_node=self.name) # loop to the same node to retry
95-
break
96-
if not new_steps:
97-
raise FatalError("No completions!")
98-
if (
99-
self.next_node
100-
and not isinstance(new_steps[-1], StopStep)
101-
and not any(isinstance(step, SetNextNode) for step in new_steps)
102-
):
103-
yield SetNextNode(next_node=self.next_node)
104-
105-
10633
def load_config(config_name: str) -> DictConfig:
10734
with hydra.initialize(config_path="conf", version_base="1.1"):
10835
config = hydra.compose(config_name=config_name)
@@ -113,20 +40,8 @@ def load_config(config_name: str) -> DictConfig:
11340
class TapeAgentArgs(AgentArgs):
11441
config: DictConfig = None # type: ignore
11542

116-
def make_agent(self, actions: tuple[ToolSpec, ...] | None) -> bgym.Agent:
117-
if actions is None:
118-
agent = hydra.utils.instantiate(self.config.agent)
119-
else:
120-
tapeagents_actions = [
121-
ToolSpec(**tool.model_dump()) if isinstance(tool, AgentlabToolSpec) else tool
122-
for tool in actions
123-
]
124-
tools_description = "\n".join([action.description() for action in actions])
125-
agent = hydra.utils.instantiate(
126-
self.config.agent,
127-
known_actions=tapeagents_actions,
128-
tools_description=tools_description,
129-
)
43+
def make_agent(self) -> bgym.Agent:
44+
agent: Agent = hydra.utils.instantiate(self.config.agent)
13045
return TapeAgent(agent=agent)
13146

13247

@@ -144,62 +59,6 @@ class DictObservation(Observation):
14459
content: str
14560

14661

147-
class MarkdownObservation(Observation):
148-
def llm_view(self, **kwargs) -> str:
149-
return f"## Markdown:\n{self.content}"
150-
151-
def short_view(self, max_chars: int = 100) -> str:
152-
return self.llm_view()[:max_chars]
153-
154-
155-
class GoalObservation(MarkdownObservation):
156-
"""
157-
Contains task goal
158-
"""
159-
160-
kind: Literal["goal_observation"] = "goal_observation" # type: ignore
161-
goal: str
162-
163-
def llm_view(self, **kwargs) -> str:
164-
return f"## Goal:\n{self.goal}"
165-
166-
167-
class HTMLPage(MarkdownObservation):
168-
"""
169-
Contains page content
170-
"""
171-
172-
kind: Literal["html_page"] = "html_page"
173-
html: str
174-
175-
def llm_view(self, **kwargs) -> str:
176-
return f"## Page Content:\n{self.html}"
177-
178-
179-
class AXTreePage(MarkdownObservation):
180-
"""
181-
Contains accessibility tree
182-
"""
183-
184-
kind: Literal["ax_tree_page"] = "ax_tree_page"
185-
axtree: str
186-
187-
def llm_view(self, **kwargs) -> str:
188-
return f"## Accessibility Tree:\n{self.axtree}"
189-
190-
191-
class ActionResult(MarkdownObservation):
192-
"""
193-
Contains action result
194-
"""
195-
196-
kind: Literal["action_result"] = "action_result"
197-
result: str
198-
199-
def llm_view(self, **kwargs) -> str:
200-
return f"## Action Result:\n{self.result}"
201-
202-
20362
class TapeAgent(bgym.Agent):
20463
agent: Agent
20564
tape: Tape
@@ -209,50 +68,23 @@ def __init__(self, agent: Agent):
20968
self.agent = agent
21069
self.tape = Tape(steps=[])
21170

212-
def obs_preprocessor(self, obs: Any) -> list[Observation]:
213-
return obs
214-
215-
def obs_to_steps(self, obs: Observation | list[Observation] | dict) -> list[Observation]:
71+
def obs_preprocessor(self, obs: Observation | list[Observation]) -> list[Observation]:
21672
if isinstance(obs, Observation):
21773
obs = [obs]
218-
if isinstance(obs, dict):
219-
obs_steps = []
220-
if obs.get("goal_object"):
221-
obs_steps.append(GoalObservation(goal=obs["goal_object"][0]["text"]))
222-
if obs.get("action_result"):
223-
obs_steps.append(ActionResult(result=obs["action_result"]))
224-
if obs.get("pruned_html"):
225-
obs_steps.append(HTMLPage(html=obs["pruned_html"]))
226-
if obs.get("axtree_txt"):
227-
obs_steps.append(AXTreePage(axtree=obs["axtree_txt"]))
228-
if obs.get("screenshot"):
229-
if isinstance(obs["screenshot"], Image.Image):
230-
tmp_image_path = tempfile.mktemp(suffix=".png")
231-
obs["screenshot"].save(tmp_image_path)
232-
obs_steps.append(ImageObservation(image_path=tmp_image_path))
233-
else:
234-
raise ValueError(f"Expected Image.Image, got {type(obs['screenshot'])}")
235-
if obs.get("last_action_error"):
236-
obs_steps.append(ActionResult(result=f"Action error:\n{obs['last_action_error']}"))
237-
assert len(obs_steps) > 0, f"Unknown dict observation, keys: {obs.keys()}"
238-
obs = obs_steps
23974
assert isinstance(obs, list), f"Expected list of Observations, got {type(obs)}"
240-
obs_view = "\n".join([o.short_view() for o in obs])
241-
logger.info(colored(f"Observations:\n{obs_view}", "green"))
75+
logger.info(f"Observations: {[type(o).__name__ for o in obs]}")
24276
return obs
24377

244-
def get_action(
245-
self, obs: Observation | list[Observation] | dict
246-
) -> tuple[Action, TapeAgentInfo]:
247-
self.tape += self.obs_to_steps(obs)
78+
def get_action(self, obs: Observation | list[Observation]) -> tuple[Action, TapeAgentInfo]:
79+
self.tape += obs # type: ignore
24880
thoughts: list[Thought] = []
24981
action = None
25082
while not action:
25183
for event in self.agent.run(self.tape):
25284
if not event.step:
25385
continue
25486
self.tape = self.tape.append(event.step)
255-
if isinstance(event.step, Thought) and not isinstance(event.step, ControlFlow):
87+
if isinstance(event.step, Thought):
25688
thoughts.append(event.step)
25789
logger.info(f"Thought: {event.step.llm_view()}")
25890
elif isinstance(event.step, Action) and not action: # we use first action only
@@ -262,11 +94,10 @@ def get_action(
26294
# there could be control flow steps for switching nodes and if clauses
26395
logger.info(f"Other step: {type(event.step)}")
26496
logger.info(f"Tape after run: ({len(self.tape)}) {[type(s).__name__ for s in self.tape]}")
265-
think_str = "\n".join([t.llm_view() for t in thoughts])
266-
return (action, {"thoughts": thoughts, "think": think_str})
97+
return (action, TapeAgentInfo(thoughts=thoughts))
26798

26899
@property
269100
def final_tape(self) -> Tape:
270101
truncated = not any([isinstance(s, StopStep) for s in self.tape.steps])
271102
self.tape.metadata = ExtendedMetadata(author=self.agent.name, truncated=truncated)
272-
return self.tape
103+
return self.tape

0 commit comments

Comments
 (0)