11import json
22import logging
33import pprint
4+ import time
45from dataclasses import dataclass
56from functools import partial
6- from typing import Callable
7+ from typing import Callable , Literal
78
8- from litellm import completion_with_retries
9+ from litellm import completion
910from litellm .types .utils import ChatCompletionMessageToolCall , Message , ModelResponse
1011from PIL import Image
1112from termcolor import colored
1718
1819logger = logging .getLogger (__name__ )
1920
class LLMArgs(BaseModelArgs):
    """Configuration for the LiteLLM-backed chat model used by the agent.

    Extends BaseModelArgs with reasoning/tool-calling knobs and builds a
    ready-to-call completion function via make_model().
    """

    # Reasoning-effort hint forwarded to reasoning-capable models.
    reasoning_effort: Literal["minimal", "low", "medium", "high"] = "low"
    # Number of retries LiteLLM performs internally on transient API errors.
    num_retries: int = 3

    def make_model(self) -> Callable:
        """Return litellm.completion pre-bound with this configuration.

        The returned callable only needs `messages=` (and optionally `tools=`)
        at call time; every other parameter is fixed here.
        """
        # NOTE(review): both max_tokens and max_completion_tokens are sent;
        # some providers reject the pair — confirm the target API accepts both.
        return partial(
            completion,
            model=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_total_tokens,
            max_completion_tokens=self.max_new_tokens,
            reasoning_effort=self.reasoning_effort,
            num_retries=self.num_retries,
            # Let the model decide when to call a tool, one call at a time.
            tool_choice="auto",
            parallel_tool_calls=False,
        )
9338
94- Step = LLMOutput | Observation | SystemMessage | UserMessage
9539
9640@dataclass
9741class AgentConfig :
@@ -112,68 +56,90 @@ class AgentConfig:
112562. Evaluate action success, explain impact on task and next steps.
113573. If you see any errors in the last observation, think about it. If there is no error, just move on.
114584. List next steps to move towards the goal and propose next immediate action.
115- Then produce the function call that performs the proposed action. If the task is complete, produce the final step.
59+ Then produce the single function call that performs the proposed action. If the task is complete, produce the final step.
11660"""
11761
118- class LLMArgs (BaseModelArgs ):
119- reasoning_effort : str = "low"
120-
121- def make_model (self ) -> Callable :
122- return partial (
123- completion_with_retries ,
124- model = self .model_name ,
125- temperature = self .temperature ,
126- max_tokens = self .max_total_tokens ,
127- max_completion_tokens = self .max_new_tokens ,
128- reasoning_effort = self .reasoning_effort ,
129- )
13062
13163class ReactToolCallAgent :
132- def __init__ (self , action_set : ToolsActionSet , llm : Callable , config : AgentConfig ):
64+ def __init__ (
65+ self , action_set : ToolsActionSet , llm : Callable [..., ModelResponse ], config : AgentConfig
66+ ):
13367 self .action_set = action_set
134- self .history : list [Step ] = [SystemMessage ( message = config .system_prompt ) ]
68+ self .history : list [dict | Message ] = [{ "role" : "system" , "content" : config .system_prompt } ]
13569 self .llm = llm
13670 self .config = config
13771 self .last_tool_call_id : str = ""
13872
13973 def obs_preprocessor (self , obs : dict ) -> dict :
140- if not self .config .use_html :
141- obs .pop ("pruned_html" , None )
142- if not self .config .use_axtree :
143- obs .pop ("axtree_txt" , None )
144- if not self .config .use_screenshot :
145- obs .pop ("screenshot" , None )
146- if self .last_tool_call_id :
147- # add tool_call_id to obs for linking observation to the last executed action
148- obs ["tool_call_id" ] = self .last_tool_call_id
14974 return obs
15075
76+ def obs_to_messages (self , obs : dict ) -> list [dict ]:
77+ """
78+ Convert the observation dictionary into a list of chat messages for Lite LLM
79+ """
80+ messages = []
81+ if obs .get ("goal_object" ) and not self .last_tool_call_id :
82+ # its a first observation when there are no tool_call_id, so include goal
83+ goal = obs ["goal_object" ][0 ]["text" ]
84+ messages .append ({"role" : "user" , "content" : f"## Goal:\n { goal } " })
85+ text_obs = []
86+ if result := obs .get ("action_result" ):
87+ text_obs .append (f"## Action Result:\n { result } " )
88+ if error := obs .get ("last_action_error" ):
89+ text_obs .append (f"## Action Error:\n { error } " )
90+ if self .config .use_html and (html := obs .get ("pruned_html" )):
91+ text_obs .append (f"## HTML:\n { html } " )
92+ if self .config .use_axtree and (axtree := obs .get ("axtree_txt" )):
93+ text_obs .append (f"## Accessibility Tree:\n { axtree } " )
94+ content = "\n \n " .join (text_obs )
95+ if content :
96+ if self .last_tool_call_id :
97+ message = {
98+ "role" : "tool" ,
99+ "tool_call_id" : self .last_tool_call_id ,
100+ "content" : content ,
101+ }
102+ else :
103+ message = {"role" : "user" , "content" : content }
104+ messages .append (message )
105+ if self .config .use_screenshot and (scr := obs .get ("screenshot" )):
106+ if isinstance (scr , Image .Image ):
107+ image_content = [
108+ {"type" : "image_url" , "image_url" : {"url" : image_to_png_base64_url (scr )}}
109+ ]
110+ messages .append ({"role" : "user" , "content" : image_content })
111+ else :
112+ raise ValueError (
113+ f"Expected Image.Image in screenshot obs, got { type (obs ['screenshot' ])} "
114+ )
115+ return messages
116+
151117 def get_action (self , obs : dict ) -> tuple [ToolCallAction , dict ]:
152- prev_actions = [step for step in self .history if isinstance (step , LLMOutput )]
153- if len (prev_actions ) >= self .config .max_actions :
118+ actions_count = len (
119+ [msg for msg in self .history if isinstance (msg , Message ) and msg .tool_calls ]
120+ )
121+ if actions_count >= self .config .max_actions :
154122 logger .warning ("Max actions reached, stopping agent." )
155- stop_action = ToolCallAction (id = "stop" , function = FunctionCall (name = "final_step" , arguments = {}))
123+ stop_action = ToolCallAction (
124+ id = "stop" , function = FunctionCall (name = "final_step" , arguments = {})
125+ )
156126 return stop_action , {}
157- self .history .append (Observation (data = obs ))
158- steps = self .history + [UserMessage (message = self .config .guidance )]
159- messages = [m for step in steps for m in step .to_messages ()]
127+ self .history += self .obs_to_messages (self .obs_preprocessor (obs ))
160128 tools = [tool .model_dump () for tool in self .action_set .actions ]
129+ messages = self .history + [{"role" : "user" , "content" : self .config .guidance }]
130+
161131 try :
162132 logger .info (colored (f"Prompt:\n { pprint .pformat (messages , width = 120 )} " , "blue" ))
163- response : ModelResponse = self .llm (
164- tools = tools ,
165- messages = messages ,
166- num_retries = self .config .max_retry ,
167- )
168- message = response .choices [0 ].message # type: ignore
133+ response = self .llm (tools = tools , messages = messages )
134+ message = response .choices [0 ].message # type: ignore
169135 except Exception as e :
170136 logger .exception (f"Error getting LLM response: { e } . Prompt: { messages } " )
171137 raise e
172138 logger .info (colored (f"LLM response:\n { pprint .pformat (message , width = 120 )} " , "green" ))
173- self .history .append (LLMOutput (message = message ))
139+
140+ self .history .append (message )
174141 thoughts = self .thoughts_from_message (message )
175142 action = self .action_from_message (message )
176-
177143 return action , {"think" : thoughts }
178144
179145 def thoughts_from_message (self , message ) -> str :
@@ -187,7 +153,7 @@ def thoughts_from_message(self, message) -> str:
187153 logger .info (colored (f"LLM thinking block:\n { thinking } " , "yellow" ))
188154 thoughts .append (thinking )
189155 if message .content :
190- logger .info (colored (f"LLM output:\n { message .content } " , "cyan" ))
156+ logger .info (colored (f"LLM text output:\n { message .content } " , "cyan" ))
191157 thoughts .append (message .content )
192158 return "\n \n " .join (thoughts )
193159
@@ -199,27 +165,27 @@ def action_from_message(self, message) -> ToolCallAction:
199165 assert isinstance (tool_call .function .name , str )
200166 try :
201167 args = json .loads (tool_call .function .arguments )
202- action = ToolCallAction (
203- id = tool_call .id ,
204- function = FunctionCall (name = tool_call .function .name , arguments = args )
205- )
206168 except json .JSONDecodeError as e :
207- logger .exception (f"Error in json parsing of tool call arguments, { e } : { tool_call .function .arguments } " )
169+ logger .exception (
170+ f"Error in json parsing of tool call arguments, { e } : { tool_call .function .arguments } "
171+ )
208172 raise e
209-
173+ action = ToolCallAction (
174+ id = tool_call .id , function = FunctionCall (name = tool_call .function .name , arguments = args )
175+ )
210176 self .last_tool_call_id = action .id
177+ logger .info (f"Parsed tool call action: { action } " )
211178 else :
212179 raise ValueError (f"No tool call found in LLM response: { message } " )
213180 return action
214-
181+
215182
@dataclass
class ReactToolCallAgentArgs(AgentArgs):
    """Factory args bundling the model config and agent config."""

    # Both fields must be populated before make_agent() is called.
    llm_args: LLMArgs | None = None
    config: AgentConfig | None = None

    def make_agent(self, actions: list[ToolSpec]) -> ReactToolCallAgent:
        """Build a ReactToolCallAgent over the given tool specs.

        Raises:
            ValueError: if llm_args or config were left unset (clearer than
                the AttributeError a None dereference would produce).
        """
        if self.llm_args is None or self.config is None:
            raise ValueError("ReactToolCallAgentArgs requires both llm_args and config to be set.")
        llm = self.llm_args.make_model()
        action_set = ToolsActionSet(actions=actions)
        return ReactToolCallAgent(action_set=action_set, llm=llm, config=self.config)