11import logging
2- import tempfile
32from dataclasses import dataclass
4- from typing import Any , Literal
3+ from typing import Literal
54
65import bgym
76import hydra
8- from litellm import ChatCompletionThinkingBlock
97from omegaconf import DictConfig
10- from PIL import Image
118from pydantic import Field
129from tapeagents .agent import Agent
13- from tapeagents .core import (
14- Action ,
15- ControlFlow ,
16- LLMOutputParsingFailureAction ,
17- Observation ,
18- SetNextNode ,
19- StopStep ,
20- TapeMetadata ,
21- Thought ,
22- )
10+ from tapeagents .core import Action , Observation , StopStep , TapeMetadata , Thought
2311from tapeagents .core import Tape as BaseTape
24- from tapeagents .llms import LLMStream
25- from tapeagents .nodes import FatalError , StandardNode
26- from tapeagents .steps import ImageObservation
27- from tapeagents .tool_calling import ToolSpec
28- from termcolor import colored
2912
30- from agentlab .actions import ToolSpec as AgentlabToolSpec
3113from agentlab .agents .agent_args import AgentArgs
3214
3315logger = logging .getLogger (__name__ )
@@ -44,65 +26,10 @@ class ExtendedMetadata(TapeMetadata):
4426 other : dict = {}
4527
4628
class AgentResponse(Thought):
    """Thought step that stores a plain response string."""

    kind: Literal["agent_response"] = "agent_response"
    response: str

    def llm_view(self, **kwargs) -> str:
        """Return the stored response text; keyword arguments are ignored."""
        return self.response
54-
class AgentThinking(Thought):
    """Thought step that stores a plain thinking/reasoning string."""

    kind: Literal["agent_thinking"] = "agent_thinking"
    thinking: str

    def llm_view(self, **kwargs) -> str:
        """Return the stored thinking text; keyword arguments are ignored."""
        return self.thinking
62-
class Tape(BaseTape):
    """Tape whose metadata defaults to a fresh ``ExtendedMetadata`` instance."""

    # default_factory gives every tape its own metadata object.
    metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata)  # type: ignore
6531
6632
class ToolCallNode(StandardNode):
    """Node that turns LLM output (reasoning, text, tool calls) into tape steps."""

    use_known_actions: bool = True
    use_function_calls: bool = True

    def generate_steps(self, agent: Agent, tape: Tape, llm_stream: LLMStream):
        """Yield the steps derived from every event of ``llm_stream``.

        Reasoning content and thinking blocks become ``AgentThinking`` steps,
        text content becomes an ``AgentResponse`` step, and each tool call is
        converted with ``self.tool_call_to_step``.

        Raises:
            FatalError: if the stream produced no steps at all.
        """
        new_steps = []
        for event in llm_stream:
            output = event.output
            if output.get("reasoning_content"):
                logger.info(colored(f"LLM reasoning:\n{output.reasoning_content}", "yellow"))
                new_steps.append(AgentThinking(thinking=output.reasoning_content))
            if output.get("thinking_blocks"):
                for block in output.thinking_blocks:
                    # litellm's ChatCompletionThinkingBlock is a TypedDict, so
                    # isinstance() against it raises TypeError and the text is
                    # stored under the "thinking" key rather than an attribute.
                    if isinstance(block, dict) and block.get("type") == "thinking":
                        logger.info(colored(f"LLM thinking block:\n{block}", "yellow"))
                        new_steps.append(AgentThinking(thinking=block.get("thinking", "")))
            if output.content:
                logger.info(colored(f"LLM output:\n{output.content}", "cyan"))
                new_steps.append(AgentResponse(response=output.content))
            if output.tool_calls:
                logger.info(colored(f"LLM tool calls:\n{output.tool_calls}", "magenta"))
                new_steps += [
                    self.tool_call_to_step(agent, tool_call) for tool_call in output.tool_calls
                ]
        # NOTE(review): the nesting of this loop was recovered from a flattened
        # source; it is placed after the event loop so each step is yielded
        # once - confirm against the original file.
        for step in new_steps:
            yield step
            if isinstance(step, LLMOutputParsingFailureAction):
                yield SetNextNode(next_node=self.name)  # loop to the same node to retry
                break
        if not new_steps:
            raise FatalError("No completions!")
        # Only route to the configured next node when the run has not stopped
        # and no explicit SetNextNode was produced by the steps themselves.
        if (
            self.next_node
            and not isinstance(new_steps[-1], StopStep)
            and not any(isinstance(step, SetNextNode) for step in new_steps)
        ):
            yield SetNextNode(next_node=self.next_node)
105-
10633def load_config (config_name : str ) -> DictConfig :
10734 with hydra .initialize (config_path = "conf" , version_base = "1.1" ):
10835 config = hydra .compose (config_name = config_name )
@@ -113,20 +40,8 @@ def load_config(config_name: str) -> DictConfig:
class TapeAgentArgs(AgentArgs):
    """Agent arguments holding the Hydra config used to build a ``TapeAgent``."""

    config: DictConfig = None  # type: ignore

    def make_agent(self) -> bgym.Agent:
        """Instantiate ``self.config.agent`` via Hydra and wrap it in a ``TapeAgent``."""
        agent: Agent = hydra.utils.instantiate(self.config.agent)
        return TapeAgent(agent=agent)
13146
13247
@@ -144,62 +59,6 @@ class DictObservation(Observation):
14459 content : str
14560
14661
class MarkdownObservation(Observation):
    """Observation rendered for the LLM under a markdown heading."""

    def llm_view(self, **kwargs) -> str:
        """Return the content under a ``## Markdown:`` heading."""
        # NOTE(review): ``content`` is not declared in this class; presumably
        # it is provided by the base class or a subclass - confirm.
        return f"## Markdown:\n{self.content}"

    def short_view(self, max_chars: int = 100) -> str:
        """Return the first ``max_chars`` characters of ``llm_view()``."""
        return self.llm_view()[:max_chars]
153-
154-
class GoalObservation(MarkdownObservation):
    """
    Contains task goal
    """

    kind: Literal["goal_observation"] = "goal_observation"  # type: ignore
    goal: str

    def llm_view(self, **kwargs) -> str:
        """Return the goal text under a ``## Goal:`` heading."""
        return f"## Goal:\n{self.goal}"
165-
166-
class HTMLPage(MarkdownObservation):
    """
    Contains page content
    """

    kind: Literal["html_page"] = "html_page"
    html: str

    def llm_view(self, **kwargs) -> str:
        """Return the HTML under a ``## Page Content:`` heading."""
        return f"## Page Content:\n{self.html}"
177-
178-
class AXTreePage(MarkdownObservation):
    """
    Contains accessibility tree
    """

    kind: Literal["ax_tree_page"] = "ax_tree_page"
    axtree: str

    def llm_view(self, **kwargs) -> str:
        """Return the tree text under a ``## Accessibility Tree:`` heading."""
        return f"## Accessibility Tree:\n{self.axtree}"
189-
190-
class ActionResult(MarkdownObservation):
    """
    Contains action result
    """

    kind: Literal["action_result"] = "action_result"
    result: str

    def llm_view(self, **kwargs) -> str:
        """Return the result text under a ``## Action Result:`` heading."""
        return f"## Action Result:\n{self.result}"
201-
202-
20362class TapeAgent (bgym .Agent ):
20463 agent : Agent
20564 tape : Tape
@@ -209,50 +68,23 @@ def __init__(self, agent: Agent):
20968 self .agent = agent
21069 self .tape = Tape (steps = [])
21170
def obs_preprocessor(self, obs: Observation | list[Observation]) -> list[Observation]:
    """Normalize an observation into a list of ``Observation`` steps.

    A single ``Observation`` is wrapped in a one-element list; a list is
    returned as-is. The class names of the observations are logged.

    Raises:
        AssertionError: if ``obs`` is neither an ``Observation`` nor a list.
    """
    if isinstance(obs, Observation):
        obs = [obs]
    assert isinstance(obs, list), f"Expected list of Observations, got {type(obs)}"
    logger.info(f"Observations: {[type(o).__name__ for o in obs]}")
    return obs
24377
244- def get_action (
245- self , obs : Observation | list [Observation ] | dict
246- ) -> tuple [Action , TapeAgentInfo ]:
247- self .tape += self .obs_to_steps (obs )
78+ def get_action (self , obs : Observation | list [Observation ]) -> tuple [Action , TapeAgentInfo ]:
79+ self .tape += obs # type: ignore
24880 thoughts : list [Thought ] = []
24981 action = None
25082 while not action :
25183 for event in self .agent .run (self .tape ):
25284 if not event .step :
25385 continue
25486 self .tape = self .tape .append (event .step )
255- if isinstance (event .step , Thought ) and not isinstance ( event . step , ControlFlow ) :
87+ if isinstance (event .step , Thought ):
25688 thoughts .append (event .step )
25789 logger .info (f"Thought: { event .step .llm_view ()} " )
25890 elif isinstance (event .step , Action ) and not action : # we use first action only
@@ -262,11 +94,10 @@ def get_action(
26294 # there could be control flow steps for switching nodes and if clauses
26395 logger .info (f"Other step: { type (event .step )} " )
26496 logger .info (f"Tape after run: ({ len (self .tape )} ) { [type (s ).__name__ for s in self .tape ]} " )
265- think_str = "\n " .join ([t .llm_view () for t in thoughts ])
266- return (action , {"thoughts" : thoughts , "think" : think_str })
97+ return (action , TapeAgentInfo (thoughts = thoughts ))
26798
@property
def final_tape(self) -> Tape:
    """Return the tape after stamping it with fresh ``ExtendedMetadata``.

    The metadata records the agent name as author and marks the tape as
    truncated when it contains no ``StopStep``.
    """
    # Generator form lets any() stop at the first StopStep found.
    truncated = not any(isinstance(s, StopStep) for s in self.tape.steps)
    self.tape.metadata = ExtendedMetadata(author=self.agent.name, truncated=truncated)
    return self.tape
0 commit comments