11import logging
2+ import tempfile
23from dataclasses import dataclass
34from typing import Literal
45
56import bgym
67import hydra
8+ from litellm import ChatCompletionThinkingBlock
79from omegaconf import DictConfig
10+ from PIL import Image
811from pydantic import Field
912from tapeagents .agent import Agent
10- from tapeagents .core import Action , Observation , StopStep , TapeMetadata , Thought
13+ from tapeagents .core import (
14+ Action ,
15+ LLMOutputParsingFailureAction ,
16+ Observation ,
17+ SetNextNode ,
18+ StopStep ,
19+ TapeMetadata ,
20+ Thought ,
21+ )
1122from tapeagents .core import Tape as BaseTape
23+ from tapeagents .llms import LLMStream
24+ from tapeagents .nodes import FatalError , StandardNode
25+ from tapeagents .steps import ImageObservation
1226from tapeagents .tool_calling import ToolSpec
27+ from termcolor import colored
1328
1429from agentlab .agents .agent_args import AgentArgs
30+ from agentlab .backends .browser .base import ToolSpec as AgentlabToolSpec
1531
1632logger = logging .getLogger (__name__ )
1733logger .setLevel (logging .INFO )
@@ -27,10 +43,59 @@ class ExtendedMetadata(TapeMetadata):
2743 other : dict = {}
2844
2945
46+ class AgentResponse (Thought ):
47+ kind : Literal ["agent_response" ] = "agent_response"
48+ response : str
49+
50+
51+ class AgentThinking (Thought ):
52+ kind : Literal ["agent_thinking" ] = "agent_thinking"
53+ thinking : str
54+
55+
3056class Tape (BaseTape ):
3157 metadata : ExtendedMetadata = Field (default_factory = ExtendedMetadata ) # type: ignore
3258
3359
60+ class ToolCallNode (StandardNode ):
61+ use_known_actions : bool = True
62+ use_function_calls : bool = True
63+
64+ def generate_steps (self , agent : Agent , tape : Tape , llm_stream : LLMStream ):
65+ new_steps = []
66+ for event in llm_stream :
67+ if event .output .get ("reasoning_content" ):
68+ logger .info (colored (f"LLM reasoning:\n { event .output .reasoning_content } " , "yellow" ))
69+ new_steps .append (AgentThinking (thinking = event .output .reasoning_content ))
70+ if event .output .get ("thinking_blocks" ):
71+ for block in event .output .thinking_blocks :
72+ if isinstance (block , ChatCompletionThinkingBlock ):
73+ logger .info (colored (f"LLM thinking block:\n { block } " , "yellow" ))
74+ new_steps .append (AgentThinking (thinking = block .content ))
75+ if event .output .content :
76+ logger .info (colored (f"LLM output:\n { event .output .content } " , "cyan" ))
77+ new_steps .append (AgentResponse (response = event .output .content ))
78+ if event .output .tool_calls :
79+ logger .info (colored (f"LLM tool calls:\n { event .output .tool_calls } " , "magenta" ))
80+ new_steps += [
81+ self .tool_call_to_step (agent , tool_call )
82+ for tool_call in event .output .tool_calls
83+ ]
84+ for step in new_steps :
85+ yield step
86+ if isinstance (step , LLMOutputParsingFailureAction ):
87+ yield SetNextNode (next_node = self .name ) # loop to the same node to retry
88+ break
89+ if not new_steps :
90+ raise FatalError ("No completions!" )
91+ if (
92+ self .next_node
93+ and not isinstance (new_steps [- 1 ], StopStep )
94+ and not any (isinstance (step , SetNextNode ) for step in new_steps )
95+ ):
96+ yield SetNextNode (next_node = self .next_node )
97+
98+
3499def load_config (config_name : str ) -> DictConfig :
35100 with hydra .initialize (config_path = "conf" , version_base = "1.1" ):
36101 config = hydra .compose (config_name = config_name )
@@ -45,8 +110,16 @@ def make_agent(self, actions: tuple[ToolSpec, ...] | None) -> bgym.Agent:
45110 if actions is None :
46111 agent = hydra .utils .instantiate (self .config .agent )
47112 else :
113+ tapeagents_actions = [
114+ ToolSpec (** tool .model_dump ()) if isinstance (tool , AgentlabToolSpec ) else tool
115+ for tool in actions
116+ ]
48117 tools_description = "\n " .join ([action .description () for action in actions ])
49- agent = hydra .utils .instantiate (self .config .agent , known_actions = actions , tools_description = tools_description )
118+ agent = hydra .utils .instantiate (
119+ self .config .agent ,
120+ known_actions = tapeagents_actions ,
121+ tools_description = tools_description ,
122+ )
50123 return TapeAgent (agent = agent )
51124
52125
@@ -64,6 +137,62 @@ class DictObservation(Observation):
64137 content : str
65138
66139
140+ class MarkdownObservation (Observation ):
141+ def llm_view (self , ** kwargs ) -> str :
142+ return f"## Markdown:\n { self .content } "
143+
144+ def short_view (self , max_chars : int = 100 ) -> str :
145+ return self .llm_view ()[:max_chars ]
146+
147+
148+ class GoalObservation (MarkdownObservation ):
149+ """
150+ Contains task goal
151+ """
152+
153+ kind : Literal ["goal_observation" ] = "goal_observation" # type: ignore
154+ goal : str
155+
156+ def llm_view (self , ** kwargs ) -> str :
157+ return f"## Goal:\n { self .goal } "
158+
159+
160+ class HTMLPage (MarkdownObservation ):
161+ """
162+ Contains page content
163+ """
164+
165+ kind : Literal ["html_page" ] = "html_page"
166+ html : str
167+
168+ def llm_view (self , ** kwargs ) -> str :
169+ return f"## Page Content:\n { self .html } "
170+
171+
172+ class AXTreePage (MarkdownObservation ):
173+ """
174+ Contains accessibility tree
175+ """
176+
177+ kind : Literal ["ax_tree_page" ] = "ax_tree_page"
178+ axtree : str
179+
180+ def llm_view (self , ** kwargs ) -> str :
181+ return f"## Accessibility Tree:\n { self .axtree } "
182+
183+
184+ class ActionResult (MarkdownObservation ):
185+ """
186+ Contains action result
187+ """
188+
189+ kind : Literal ["action_result" ] = "action_result"
190+ result : str
191+
192+ def llm_view (self , ** kwargs ) -> str :
193+ return f"## Action Result:\n { self .result } "
194+
195+
67196class TapeAgent (bgym .Agent ):
68197 agent : Agent
69198 tape : Tape
@@ -73,11 +202,33 @@ def __init__(self, agent: Agent):
73202 self .agent = agent
74203 self .tape = Tape (steps = [])
75204
76- def obs_preprocessor (self , obs : Observation | list [Observation ]) -> list [Observation ]:
205+ def obs_preprocessor (self , obs : Observation | list [Observation ] | dict ) -> list [Observation ]:
77206 if isinstance (obs , Observation ):
78207 obs = [obs ]
208+ if isinstance (obs , dict ):
209+ obs_steps = []
210+ if obs .get ("goal_object" ):
211+ obs_steps .append (GoalObservation (goal = obs ["goal_object" ][0 ]["text" ]))
212+ if obs .get ("action_result" ):
213+ obs_steps .append (ActionResult (result = obs ["action_result" ]))
214+ if obs .get ("pruned_html" ):
215+ obs_steps .append (HTMLPage (html = obs ["pruned_html" ]))
216+ if obs .get ("axtree_txt" ):
217+ obs_steps .append (AXTreePage (axtree = obs ["axtree_txt" ]))
218+ if obs .get ("screenshot" ):
219+ if isinstance (obs ["screenshot" ], Image .Image ):
220+ tmp_image_path = tempfile .mktemp (suffix = ".png" )
221+ obs ["screenshot" ].save (tmp_image_path )
222+ obs_steps .append (ImageObservation (image_path = tmp_image_path ))
223+ else :
224+ raise ValueError (f"Expected Image.Image, got { type (obs ['screenshot' ])} " )
225+ if obs .get ("last_action_error" ):
226+ obs_steps .append (ActionResult (result = f"Action error:\n { obs ['last_action_error' ]} " ))
227+ assert len (obs_steps ) > 0 , f"Unknown dict observation, keys: { obs .keys ()} "
228+ obs = obs_steps
79229 assert isinstance (obs , list ), f"Expected list of Observations, got { type (obs )} "
80- logger .info (f"Observations: { [type (o ).__name__ for o in obs ]} " )
230+ obs_view = "\n " .join ([o .short_view () for o in obs ])
231+ logger .info (colored (f"Observations:\n { obs_view } " , "green" ))
81232 return obs
82233
83234 def get_action (self , obs : Observation | list [Observation ]) -> tuple [Action , TapeAgentInfo ]:
0 commit comments