4
4
# @ProjectName: browser-use-webui
5
5
# @FileName: custom_agent.py
6
6
7
- import asyncio
8
- import base64
9
- import io
10
7
import json
11
8
import logging
12
- import os
13
- import pdb
14
- import textwrap
15
- import time
16
- import uuid
17
- from io import BytesIO
18
- from pathlib import Path
19
- from typing import Any , Optional , Type , TypeVar
20
-
21
- from dotenv import load_dotenv
22
- from langchain_core .language_models .chat_models import BaseChatModel
23
- from langchain_core .messages import (
24
- BaseMessage ,
25
- SystemMessage ,
26
- )
27
- from openai import RateLimitError
28
- from PIL import Image , ImageDraw , ImageFont
29
- from pydantic import BaseModel , ValidationError
9
+ from typing import Optional , Type
30
10
31
- from browser_use .agent .message_manager .service import MessageManager
32
- from browser_use .agent .prompts import AgentMessagePrompt , SystemPrompt
11
+ from browser_use .agent .prompts import SystemPrompt
33
12
from browser_use .agent .service import Agent
34
13
from browser_use .agent .views import (
35
14
ActionResult ,
36
- AgentError ,
37
- AgentHistory ,
38
15
AgentHistoryList ,
39
16
AgentOutput ,
40
- AgentStepInfo ,
41
17
)
42
18
from browser_use .browser .browser import Browser
43
19
from browser_use .browser .context import BrowserContext
44
- from browser_use .browser .views import BrowserState , BrowserStateHistory
45
- from browser_use .controller .registry .views import ActionModel
46
20
from browser_use .controller .service import Controller
47
- from browser_use .dom .history_tree_processor .service import (
48
- DOMHistoryElement ,
49
- HistoryTreeProcessor ,
50
- )
51
- from browser_use .telemetry .service import ProductTelemetry
52
21
from browser_use .telemetry .views import (
53
22
AgentEndTelemetryEvent ,
54
23
AgentRunTelemetryEvent ,
55
24
AgentStepErrorTelemetryEvent ,
56
25
)
57
26
from browser_use .utils import time_execution_async
27
+ from langchain_core .language_models .chat_models import BaseChatModel
28
+ from langchain_core .messages import (
29
+ BaseMessage ,
30
+ )
58
31
59
- from .custom_views import CustomAgentOutput , CustomAgentStepInfo
60
32
from .custom_massage_manager import CustomMassageManager
33
+ from .custom_views import CustomAgentOutput , CustomAgentStepInfo
61
34
62
35
# Module-level logger, named after this module via __name__.
logger = logging .getLogger (__name__ )
63
36
64
37
65
38
class CustomAgent (Agent ):
66
-
67
39
def __init__ (
68
- self ,
69
- task : str ,
70
- llm : BaseChatModel ,
71
- add_infos : str = '' ,
72
- browser : Browser | None = None ,
73
- browser_context : BrowserContext | None = None ,
74
- controller : Controller = Controller (),
75
- use_vision : bool = True ,
76
- save_conversation_path : Optional [str ] = None ,
77
- max_failures : int = 5 ,
78
- retry_delay : int = 10 ,
79
- system_prompt_class : Type [SystemPrompt ] = SystemPrompt ,
80
- max_input_tokens : int = 128000 ,
81
- validate_output : bool = False ,
82
- include_attributes : list [str ] = [
83
- ' title' ,
84
- ' type' ,
85
- ' name' ,
86
- ' role' ,
87
- ' tabindex' ,
88
- ' aria-label' ,
89
- ' placeholder' ,
90
- ' value' ,
91
- ' alt' ,
92
- ' aria-expanded' ,
93
- ],
94
- max_error_length : int = 400 ,
95
- max_actions_per_step : int = 10 ,
40
+ self ,
41
+ task : str ,
42
+ llm : BaseChatModel ,
43
+ add_infos : str = "" ,
44
+ browser : Browser | None = None ,
45
+ browser_context : BrowserContext | None = None ,
46
+ controller : Controller = Controller (),
47
+ use_vision : bool = True ,
48
+ save_conversation_path : Optional [str ] = None ,
49
+ max_failures : int = 5 ,
50
+ retry_delay : int = 10 ,
51
+ system_prompt_class : Type [SystemPrompt ] = SystemPrompt ,
52
+ max_input_tokens : int = 128000 ,
53
+ validate_output : bool = False ,
54
+ include_attributes : list [str ] = [
55
+ " title" ,
56
+ " type" ,
57
+ " name" ,
58
+ " role" ,
59
+ " tabindex" ,
60
+ " aria-label" ,
61
+ " placeholder" ,
62
+ " value" ,
63
+ " alt" ,
64
+ " aria-expanded" ,
65
+ ],
66
+ max_error_length : int = 400 ,
67
+ max_actions_per_step : int = 10 ,
96
68
):
97
- super ().__init__ (task , llm , browser , browser_context , controller , use_vision , save_conversation_path ,
98
- max_failures , retry_delay , system_prompt_class , max_input_tokens , validate_output ,
99
- include_attributes , max_error_length , max_actions_per_step )
69
+ super ().__init__ (
70
+ task ,
71
+ llm ,
72
+ browser ,
73
+ browser_context ,
74
+ controller ,
75
+ use_vision ,
76
+ save_conversation_path ,
77
+ max_failures ,
78
+ retry_delay ,
79
+ system_prompt_class ,
80
+ max_input_tokens ,
81
+ validate_output ,
82
+ include_attributes ,
83
+ max_error_length ,
84
+ max_actions_per_step ,
85
+ )
100
86
self .add_infos = add_infos
101
87
self .message_manager = CustomMassageManager (
102
88
llm = self .llm ,
@@ -118,24 +104,26 @@ def _setup_action_models(self) -> None:
118
104
119
105
def _log_response(self, response: CustomAgentOutput) -> None:
    """Log the model's response, one INFO record per field.

    Emits the evaluation of the previous action (prefixed with a status
    emoji), the new memory, the task progress, the thought and the summary
    read from ``response.current_state``, followed by one record for each
    entry in ``response.action``.

    Args:
        response: Parsed model output; this method reads its
            ``current_state`` fields and iterates its ``action`` entries.
    """
    state = response.current_state
    evaluation = state.prev_action_evaluation

    # Match the bare keywords so that an evaluation which *begins* with
    # "Success" or "Failed" is recognised, not only one where the keyword
    # is preceded by a space.
    if "Success" in evaluation:
        emoji = "✅"
    elif "Failed" in evaluation:
        emoji = "❌"
    else:
        emoji = "🤷"

    logger.info(f" {emoji} Eval: {evaluation} ")
    logger.info(f" 🧠 New Memory: {state.important_contents} ")
    logger.info(f" ⏳ Task Progress: {state.completed_contents} ")
    logger.info(f" 🤔 Thought: {state.thought} ")
    logger.info(f" 🎯 Summary: {state.summary} ")

    total = len(response.action)
    # Actions are reported 1-based, as "Action <k> /<total>".
    for index, action in enumerate(response.action, start=1):
        logger.info(
            f" 🛠️ Action {index} /{total} : {action.model_dump_json(exclude_unset=True)} "
        )
137
123
138
- def update_step_info (self , model_output : CustomAgentOutput , step_info : CustomAgentStepInfo = None ):
124
+ def update_step_info (
125
+ self , model_output : CustomAgentOutput , step_info : CustomAgentStepInfo = None
126
+ ):
139
127
"""
140
128
update step info
141
129
"""
@@ -144,19 +132,23 @@ def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAge
144
132
145
133
step_info .step_number += 1
146
134
important_contents = model_output .current_state .important_contents
147
- if important_contents and 'None' not in important_contents and important_contents not in step_info .memory :
148
- step_info .memory += important_contents + '\n '
135
+ if (
136
+ important_contents
137
+ and "None" not in important_contents
138
+ and important_contents not in step_info .memory
139
+ ):
140
+ step_info .memory += important_contents + "\n "
149
141
150
142
completed_contents = model_output .current_state .completed_contents
151
- if completed_contents and ' None' not in completed_contents :
143
+ if completed_contents and " None" not in completed_contents :
152
144
step_info .task_progress = completed_contents
153
145
154
- @time_execution_async (' --get_next_action' )
146
+ @time_execution_async (" --get_next_action" )
155
147
async def get_next_action (self , input_messages : list [BaseMessage ]) -> AgentOutput :
156
148
"""Get next action from LLM based on current state"""
157
149
158
150
ret = self .llm .invoke (input_messages )
159
- parsed_json = json .loads (ret .content .replace (' ```json' , '' ).replace ("```" , "" ))
151
+ parsed_json = json .loads (ret .content .replace (" ```json" , "" ).replace ("```" , "" ))
160
152
parsed : AgentOutput = self .AgentOutput (** parsed_json )
161
153
# cut the number of actions to max_actions_per_step
162
154
parsed .action = parsed .action [: self .max_actions_per_step ]
@@ -165,10 +157,10 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu
165
157
166
158
return parsed
167
159
168
- @time_execution_async (' --step' )
160
+ @time_execution_async (" --step" )
169
161
async def step (self , step_info : Optional [CustomAgentStepInfo ] = None ) -> None :
170
162
"""Execute one step of the task"""
171
- logger .info (f' \n 📍 Step { self .n_steps } ' )
163
+ logger .info (f" \n 📍 Step { self .n_steps } " )
172
164
state = None
173
165
model_output = None
174
166
result : list [ActionResult ] = []
@@ -179,7 +171,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
179
171
input_messages = self .message_manager .get_messages ()
180
172
model_output = await self .get_next_action (input_messages )
181
173
self .update_step_info (model_output , step_info )
182
- logger .info (f' 🧠 All Memory: { step_info .memory } ' )
174
+ logger .info (f" 🧠 All Memory: { step_info .memory } " )
183
175
self ._save_conversation (input_messages , model_output )
184
176
self .message_manager ._remove_last_state_message () # we dont want the whole state in the chat history
185
177
self .message_manager .add_model_output (model_output )
@@ -190,7 +182,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
190
182
self ._last_result = result
191
183
192
184
if len (result ) > 0 and result [- 1 ].is_done :
193
- logger .info (f' 📄 Result: { result [- 1 ].extracted_content } ' )
185
+ logger .info (f" 📄 Result: { result [- 1 ].extracted_content } " )
194
186
195
187
self .consecutive_failures = 0
196
188
@@ -215,7 +207,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
215
207
async def run (self , max_steps : int = 100 ) -> AgentHistoryList :
216
208
"""Execute the task with maximum number of steps"""
217
209
try :
218
- logger .info (f' 🚀 Starting task: { self .task } ' )
210
+ logger .info (f" 🚀 Starting task: { self .task } " )
219
211
220
212
self .telemetry .capture (
221
213
AgentRunTelemetryEvent (
@@ -224,13 +216,14 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
224
216
)
225
217
)
226
218
227
- step_info = CustomAgentStepInfo (task = self .task ,
228
- add_infos = self .add_infos ,
229
- step_number = 1 ,
230
- max_steps = max_steps ,
231
- memory = '' ,
232
- task_progress = ''
233
- )
219
+ step_info = CustomAgentStepInfo (
220
+ task = self .task ,
221
+ add_infos = self .add_infos ,
222
+ step_number = 1 ,
223
+ max_steps = max_steps ,
224
+ memory = "" ,
225
+ task_progress = "" ,
226
+ )
234
227
235
228
for step in range (max_steps ):
236
229
if self ._too_many_failures ():
@@ -240,15 +233,15 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
240
233
241
234
if self .history .is_done ():
242
235
if (
243
- self .validate_output and step < max_steps - 1
236
+ self .validate_output and step < max_steps - 1
244
237
): # if last step, we dont need to validate
245
238
if not await self ._validate_output ():
246
239
continue
247
240
248
- logger .info (' ✅ Task completed successfully' )
241
+ logger .info (" ✅ Task completed successfully" )
249
242
break
250
243
else :
251
- logger .info (' ❌ Failed to complete task in maximum steps' )
244
+ logger .info (" ❌ Failed to complete task in maximum steps" )
252
245
253
246
return self .history
254
247
0 commit comments