4
4
# @ProjectName: browser-use-webui
5
5
# @FileName: custom_agent.py
6
6
7
- import asyncio
8
- import base64
9
- import io
10
7
import json
11
8
import logging
12
- import os
13
9
import pdb
14
- import textwrap
15
- import time
16
- import uuid
17
- from io import BytesIO
18
- from pathlib import Path
19
- from typing import Any , Optional , Type , TypeVar
20
-
21
- from dotenv import load_dotenv
22
- from langchain_core .language_models .chat_models import BaseChatModel
23
- from langchain_core .messages import (
24
- BaseMessage ,
25
- SystemMessage ,
26
- )
27
- from openai import RateLimitError
28
- from PIL import Image , ImageDraw , ImageFont
29
- from pydantic import BaseModel , ValidationError
10
+ import traceback
11
+ from typing import Optional , Type
30
12
31
- from browser_use .agent .message_manager .service import MessageManager
32
- from browser_use .agent .prompts import AgentMessagePrompt , SystemPrompt
13
+ from browser_use .agent .prompts import SystemPrompt
33
14
from browser_use .agent .service import Agent
34
15
from browser_use .agent .views import (
35
16
ActionResult ,
36
- AgentError ,
37
- AgentHistory ,
38
17
AgentHistoryList ,
39
18
AgentOutput ,
40
- AgentStepInfo ,
41
19
)
42
20
from browser_use .browser .browser import Browser
43
21
from browser_use .browser .context import BrowserContext
44
- from browser_use .browser .views import BrowserState , BrowserStateHistory
45
- from browser_use .controller .registry .views import ActionModel
46
22
from browser_use .controller .service import Controller
47
- from browser_use .dom .history_tree_processor .service import (
48
- DOMHistoryElement ,
49
- HistoryTreeProcessor ,
50
- )
51
- from browser_use .telemetry .service import ProductTelemetry
52
23
from browser_use .telemetry .views import (
53
24
AgentEndTelemetryEvent ,
54
25
AgentRunTelemetryEvent ,
55
26
AgentStepErrorTelemetryEvent ,
56
27
)
57
28
from browser_use .utils import time_execution_async
29
+ from langchain_core .language_models .chat_models import BaseChatModel
30
+ from langchain_core .messages import (
31
+ BaseMessage ,
32
+ )
58
33
59
- from .custom_views import CustomAgentOutput , CustomAgentStepInfo
60
34
from .custom_massage_manager import CustomMassageManager
35
+ from .custom_views import CustomAgentOutput , CustomAgentStepInfo
61
36
62
37
logger = logging .getLogger (__name__ )
63
38
64
39
65
40
class CustomAgent (Agent ):
66
-
67
41
def __init__ (
68
42
self ,
69
43
task : str ,
70
44
llm : BaseChatModel ,
71
- add_infos : str = '' ,
45
+ add_infos : str = "" ,
72
46
browser : Browser | None = None ,
73
47
browser_context : BrowserContext | None = None ,
74
48
controller : Controller = Controller (),
@@ -80,23 +54,39 @@ def __init__(
80
54
max_input_tokens : int = 128000 ,
81
55
validate_output : bool = False ,
82
56
include_attributes : list [str ] = [
83
- ' title' ,
84
- ' type' ,
85
- ' name' ,
86
- ' role' ,
87
- ' tabindex' ,
88
- ' aria-label' ,
89
- ' placeholder' ,
90
- ' value' ,
91
- ' alt' ,
92
- ' aria-expanded' ,
57
+ " title" ,
58
+ " type" ,
59
+ " name" ,
60
+ " role" ,
61
+ " tabindex" ,
62
+ " aria-label" ,
63
+ " placeholder" ,
64
+ " value" ,
65
+ " alt" ,
66
+ " aria-expanded" ,
93
67
],
94
68
max_error_length : int = 400 ,
95
69
max_actions_per_step : int = 10 ,
70
+ tool_call_in_content : bool = True ,
96
71
):
97
- super ().__init__ (task , llm , browser , browser_context , controller , use_vision , save_conversation_path ,
98
- max_failures , retry_delay , system_prompt_class , max_input_tokens , validate_output ,
99
- include_attributes , max_error_length , max_actions_per_step )
72
+ super ().__init__ (
73
+ task = task ,
74
+ llm = llm ,
75
+ browser = browser ,
76
+ browser_context = browser_context ,
77
+ controller = controller ,
78
+ use_vision = use_vision ,
79
+ save_conversation_path = save_conversation_path ,
80
+ max_failures = max_failures ,
81
+ retry_delay = retry_delay ,
82
+ system_prompt_class = system_prompt_class ,
83
+ max_input_tokens = max_input_tokens ,
84
+ validate_output = validate_output ,
85
+ include_attributes = include_attributes ,
86
+ max_error_length = max_error_length ,
87
+ max_actions_per_step = max_actions_per_step ,
88
+ tool_call_in_content = tool_call_in_content ,
89
+ )
100
90
self .add_infos = add_infos
101
91
self .message_manager = CustomMassageManager (
102
92
llm = self .llm ,
@@ -107,6 +97,7 @@ def __init__(
107
97
include_attributes = self .include_attributes ,
108
98
max_error_length = self .max_error_length ,
109
99
max_actions_per_step = self .max_actions_per_step ,
100
+ tool_call_in_content = tool_call_in_content ,
110
101
)
111
102
112
103
def _setup_action_models (self ) -> None :
@@ -118,24 +109,26 @@ def _setup_action_models(self) -> None:
118
109
119
110
def _log_response(self, response: CustomAgentOutput) -> None:
    """Write the model's reply to the module logger at INFO level.

    Logs the evaluation of the previous action (prefixed with an emoji chosen
    from the evaluation text), the remaining fields of
    ``response.current_state``, and then one line per proposed action.

    Args:
        response: The parsed model output to report.
    """
    state = response.current_state
    evaluation = state.prev_action_evaluation

    # The first marker found in the evaluation text wins; the shrug emoji is
    # used when neither marker is present.
    emoji = "🤷"
    for marker, symbol in ((" Success", "✅"), (" Failed", "❌")):
        if marker in evaluation:
            emoji = symbol
            break

    logger.info(f" {emoji} Eval: {evaluation} ")
    logger.info(f" 🧠 New Memory: {state.important_contents} ")
    logger.info(f" ⏳ Task Progress: {state.completed_contents} ")
    logger.info(f" 🤔 Thought: {state.thought} ")
    logger.info(f" 🎯 Summary: {state.summary} ")

    actions = response.action
    total = len(actions)
    for position, action in enumerate(actions, start=1):
        payload = action.model_dump_json(exclude_unset=True)
        logger.info(f" 🛠️ Action {position} /{total} : {payload} ")
137
128
138
- def update_step_info (self , model_output : CustomAgentOutput , step_info : CustomAgentStepInfo = None ):
129
+ def update_step_info (
130
+ self , model_output : CustomAgentOutput , step_info : CustomAgentStepInfo = None
131
+ ):
139
132
"""
140
133
update step info
141
134
"""
@@ -144,31 +137,54 @@ def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAge
144
137
145
138
step_info .step_number += 1
146
139
important_contents = model_output .current_state .important_contents
147
- if important_contents and 'None' not in important_contents and important_contents not in step_info .memory :
148
- step_info .memory += important_contents + '\n '
140
+ if (
141
+ important_contents
142
+ and "None" not in important_contents
143
+ and important_contents not in step_info .memory
144
+ ):
145
+ step_info .memory += important_contents + "\n "
149
146
150
147
completed_contents = model_output .current_state .completed_contents
151
- if completed_contents and ' None' not in completed_contents :
148
+ if completed_contents and " None" not in completed_contents :
152
149
step_info .task_progress = completed_contents
153
150
154
@time_execution_async(" --get_next_action")
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
    """Ask the LLM for the next action(s) based on the current messages.

    The structured-output API is tried first. If that call raises, or returns
    no parsed object, the model is queried again without structured output
    and its reply is parsed as JSON by hand (temporary workaround for
    DeepSeek, per the original comment).

    Args:
        input_messages: Messages sent to the LLM.

    Returns:
        The parsed agent output, with its action list truncated to
        ``self.max_actions_per_step`` entries.

    Raises:
        ValueError: If the fallback reply contains no content to parse.
        json.JSONDecodeError: If the fallback reply is not valid JSON once
            the Markdown code fences are stripped.
    """
    parsed = None
    try:
        structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
        response = await structured_llm.ainvoke(input_messages)  # type: ignore
        parsed = response["parsed"]
    except Exception as e:
        # Deliberate best-effort: any failure here falls through to the
        # manual-parsing path below, but the cause is logged, not dropped.
        logger.warning(
            "Structured output failed (%s: %s); falling back to manual JSON parsing",
            type(e).__name__,
            e,
        )

    if parsed is None:
        # Use the async API so the event loop is not blocked while waiting
        # for the model.
        ret = await self.llm.ainvoke(input_messages)
        content = ret.content
        if isinstance(content, list):
            if not content:
                raise ValueError("Could not parse response.")
            content = content[0]
            # NOTE(review): list entries may be content-block dicts rather
            # than plain strings; the "text" key is assumed here — confirm
            # against the chat model in use.
            if isinstance(content, dict):
                content = content.get("text", "")
        # Strip Markdown code fences the model may have wrapped around the JSON.
        parsed_json = json.loads(content.replace("```json", "").replace("```", ""))
        parsed = self.AgentOutput(**parsed_json)

    # cut the number of actions to max_actions_per_step
    parsed.action = parsed.action[: self.max_actions_per_step]
    self._log_response(parsed)
    self.n_steps += 1

    return parsed
183
+
184
+ @time_execution_async ("--step" )
169
185
async def step (self , step_info : Optional [CustomAgentStepInfo ] = None ) -> None :
170
186
"""Execute one step of the task"""
171
- logger .info (f' \n 📍 Step { self .n_steps } ' )
187
+ logger .info (f" \n 📍 Step { self .n_steps } " )
172
188
state = None
173
189
model_output = None
174
190
result : list [ActionResult ] = []
@@ -179,7 +195,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
179
195
input_messages = self .message_manager .get_messages ()
180
196
model_output = await self .get_next_action (input_messages )
181
197
self .update_step_info (model_output , step_info )
182
- logger .info (f' 🧠 All Memory: { step_info .memory } ' )
198
+ logger .info (f" 🧠 All Memory: { step_info .memory } " )
183
199
self ._save_conversation (input_messages , model_output )
184
200
self .message_manager ._remove_last_state_message () # we dont want the whole state in the chat history
185
201
self .message_manager .add_model_output (model_output )
@@ -190,7 +206,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
190
206
self ._last_result = result
191
207
192
208
if len (result ) > 0 and result [- 1 ].is_done :
193
- logger .info (f' 📄 Result: { result [- 1 ].extracted_content } ' )
209
+ logger .info (f" 📄 Result: { result [- 1 ].extracted_content } " )
194
210
195
211
self .consecutive_failures = 0
196
212
@@ -215,7 +231,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
215
231
async def run (self , max_steps : int = 100 ) -> AgentHistoryList :
216
232
"""Execute the task with maximum number of steps"""
217
233
try :
218
- logger .info (f' 🚀 Starting task: { self .task } ' )
234
+ logger .info (f" 🚀 Starting task: { self .task } " )
219
235
220
236
self .telemetry .capture (
221
237
AgentRunTelemetryEvent (
@@ -224,13 +240,14 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
224
240
)
225
241
)
226
242
227
- step_info = CustomAgentStepInfo (task = self .task ,
228
- add_infos = self .add_infos ,
229
- step_number = 1 ,
230
- max_steps = max_steps ,
231
- memory = '' ,
232
- task_progress = ''
233
- )
243
+ step_info = CustomAgentStepInfo (
244
+ task = self .task ,
245
+ add_infos = self .add_infos ,
246
+ step_number = 1 ,
247
+ max_steps = max_steps ,
248
+ memory = "" ,
249
+ task_progress = "" ,
250
+ )
234
251
235
252
for step in range (max_steps ):
236
253
if self ._too_many_failures ():
@@ -245,10 +262,10 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
245
262
if not await self ._validate_output ():
246
263
continue
247
264
248
- logger .info (' ✅ Task completed successfully' )
265
+ logger .info (" ✅ Task completed successfully" )
249
266
break
250
267
else :
251
- logger .info (' ❌ Failed to complete task in maximum steps' )
268
+ logger .info (" ❌ Failed to complete task in maximum steps" )
252
269
253
270
return self .history
254
271
0 commit comments