5
5
# @FileName: custom_agent.py
6
6
7
7
import asyncio
8
+ import base64
9
+ import io
8
10
import json
9
11
import logging
10
12
import os
13
+ import pdb
14
+ import textwrap
11
15
import time
12
16
import uuid
17
+ from io import BytesIO
13
18
from pathlib import Path
14
19
from typing import Any , Optional , Type , TypeVar
15
20
20
25
SystemMessage ,
21
26
)
22
27
from openai import RateLimitError
28
+ from PIL import Image , ImageDraw , ImageFont
23
29
from pydantic import BaseModel , ValidationError
24
30
25
31
from browser_use .agent .message_manager .service import MessageManager
26
32
from browser_use .agent .prompts import AgentMessagePrompt , SystemPrompt
33
+ from browser_use .agent .service import Agent
27
34
from browser_use .agent .views import (
28
35
ActionResult ,
29
36
AgentError ,
32
39
AgentOutput ,
33
40
AgentStepInfo ,
34
41
)
42
+ from browser_use .browser .browser import Browser
43
+ from browser_use .browser .context import BrowserContext
44
+ from browser_use .browser .views import BrowserState , BrowserStateHistory
45
+ from browser_use .controller .registry .views import ActionModel
46
+ from browser_use .controller .service import Controller
47
+ from browser_use .dom .history_tree_processor .service import (
48
+ DOMHistoryElement ,
49
+ HistoryTreeProcessor ,
50
+ )
51
+ from browser_use .telemetry .service import ProductTelemetry
35
52
from browser_use .telemetry .views import (
36
53
AgentEndTelemetryEvent ,
37
54
AgentRunTelemetryEvent ,
38
55
AgentStepErrorTelemetryEvent ,
39
56
)
40
- from browser_use .agent .service import Agent
41
57
from browser_use .utils import time_execution_async
42
58
43
- from .custom_views import CustomAgentOutput
59
+ from .custom_views import CustomAgentOutput , CustomAgentStepInfo
60
+ from .custom_massage_manager import CustomMassageManager
44
61
45
62
logger = logging .getLogger (__name__ )
46
63
47
64
48
65
class CustomAgent (Agent ):
49
66
67
def __init__(
    self,
    task: str,
    llm: BaseChatModel,
    add_infos: str = '',
    browser: Browser | None = None,
    browser_context: BrowserContext | None = None,
    controller: Controller | None = None,
    use_vision: bool = True,
    save_conversation_path: Optional[str] = None,
    max_failures: int = 5,
    retry_delay: int = 10,
    system_prompt_class: Type[SystemPrompt] = SystemPrompt,
    max_input_tokens: int = 128000,
    validate_output: bool = False,
    include_attributes: list[str] | None = None,
    max_error_length: int = 400,
    max_actions_per_step: int = 10,
):
    """Initialise the agent and swap in the custom message manager.

    Every argument except ``add_infos`` is forwarded positionally, in the
    order shown below, to ``Agent.__init__``.

    Args:
        task: Forwarded to ``Agent.__init__``.
        llm: Forwarded to ``Agent.__init__``.
        add_infos: Extra information string stored on ``self.add_infos``.
        browser: Forwarded to ``Agent.__init__``.
        browser_context: Forwarded to ``Agent.__init__``.
        controller: Forwarded to ``Agent.__init__``. ``None`` (the default)
            creates a fresh ``Controller()`` for this agent.
        use_vision: Forwarded to ``Agent.__init__``.
        save_conversation_path: Forwarded to ``Agent.__init__``.
        max_failures: Forwarded to ``Agent.__init__``.
        retry_delay: Forwarded to ``Agent.__init__``.
        system_prompt_class: Forwarded to ``Agent.__init__``.
        max_input_tokens: Forwarded to ``Agent.__init__``.
        validate_output: Forwarded to ``Agent.__init__``.
        include_attributes: Forwarded to ``Agent.__init__``. ``None`` (the
            default) selects the standard attribute list built below.
        max_error_length: Forwarded to ``Agent.__init__``.
        max_actions_per_step: Forwarded to ``Agent.__init__``.
    """
    # Default arguments are evaluated once, at definition time. Resolving
    # them here gives every agent its own Controller and its own list
    # instead of silently sharing one instance across all agents.
    if controller is None:
        controller = Controller()
    if include_attributes is None:
        include_attributes = [
            'title',
            'type',
            'name',
            'role',
            'tabindex',
            'aria-label',
            'placeholder',
            'value',
            'alt',
            'aria-expanded',
        ]

    super().__init__(
        task,
        llm,
        browser,
        browser_context,
        controller,
        use_vision,
        save_conversation_path,
        max_failures,
        retry_delay,
        system_prompt_class,
        max_input_tokens,
        validate_output,
        include_attributes,
        max_error_length,
        max_actions_per_step,
    )
    self.add_infos = add_infos
    # Replace the message manager with the custom one.
    # NOTE(review): assumes Agent.__init__ has set the self.* attributes
    # read below (llm, task, controller, ...) - confirm against the base class.
    self.message_manager = CustomMassageManager(
        llm=self.llm,
        task=self.task,
        action_descriptions=self.controller.registry.get_prompt_description(),
        system_prompt_class=self.system_prompt_class,
        max_input_tokens=self.max_input_tokens,
        include_attributes=self.include_attributes,
        max_error_length=self.max_error_length,
        max_actions_per_step=self.max_actions_per_step,
    )
111
+
50
112
def _setup_action_models (self ) -> None :
51
113
"""Setup dynamic action models from controller's registry"""
52
114
# Get the dynamic action model from controller's registry
@@ -56,23 +118,42 @@ def _setup_action_models(self) -> None:
56
118
57
119
def _log_response(self, response: CustomAgentOutput) -> None:
    """Write the model's response to the module logger at INFO level.

    Logs, in order: the previous-action evaluation (prefixed with an emoji
    chosen from its text), the important contents, the completed contents,
    the thought, the summary, and then one line per action.

    Args:
        response: Model output whose ``current_state`` and ``action`` are logged.
    """
    current = response.current_state
    evaluation = current.prev_action_evaluation

    # The first marker found in the evaluation text picks the emoji;
    # anything else falls back to the shrug.
    emoji = '🤷'
    for marker, symbol in (('Success', '✅ '), ('Failed', '❌ ')):
        if marker in evaluation:
            emoji = symbol
            break

    logger.info(f'{emoji} Eval: {evaluation} ')
    logger.info(f'🧠 Memory: {current.import_contents} ')
    logger.info(f'⏳ Task Progress: {current.completed_contents} ')
    logger.info(f'🤔 Thought: {current.thought} ')
    logger.info(f'🎯 Summary: {current.summary} ')

    total_actions = len(response.action)
    for position, action in enumerate(response.action, start=1):
        action_json = action.model_dump_json(exclude_unset=True)
        logger.info(f'🛠️ Action {position} /{total_actions} : {action_json} ')
73
137
138
def update_step_info(
    self,
    model_output: CustomAgentOutput,
    step_info: Optional[CustomAgentStepInfo] = None,
) -> None:
    """Advance ``step_info`` using the latest model output.

    Does nothing when ``step_info`` is ``None``. Otherwise the step counter
    is incremented, new important contents are appended to the accumulated
    memory, and the task progress is replaced by the reported completed
    contents.

    Args:
        model_output: Output whose ``current_state.import_contents`` and
            ``current_state.completed_contents`` are read.
        step_info: Mutable step record updated in place, or ``None``.
    """
    if step_info is None:
        return

    step_info.step_number += 1

    # Append only contents that are non-empty, do not contain the text
    # 'None', and are not already a substring of the accumulated memory.
    import_contents = model_output.current_state.import_contents
    if import_contents and 'None' not in import_contents and import_contents not in step_info.memory:
        step_info.memory += import_contents + '\n '

    # Progress is overwritten rather than accumulated.
    completed_contents = model_output.current_state.completed_contents
    if completed_contents and 'None' not in completed_contents:
        step_info.task_progress = completed_contents
153
+
154
+
74
155
@time_execution_async ('--step' )
75
- async def step (self , step_info : Optional [AgentStepInfo ] = None ) -> None :
156
+ async def step (self , step_info : Optional [CustomAgentStepInfo ] = None ) -> None :
76
157
"""Execute one step of the task"""
77
158
logger .info (f'\n 📍 Step { self .n_steps } ' )
78
159
state = None
@@ -84,6 +165,7 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
84
165
self .message_manager .add_state_message (state , self ._last_result , step_info )
85
166
input_messages = self .message_manager .get_messages ()
86
167
model_output = await self .get_next_action (input_messages )
168
+ self .update_step_info (model_output , step_info )
87
169
self ._save_conversation (input_messages , model_output )
88
170
self .message_manager ._remove_last_state_message () # we dont want the whole state in the chat history
89
171
self .message_manager .add_model_output (model_output )
@@ -115,3 +197,87 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
115
197
)
116
198
if state :
117
199
self ._make_history_item (model_output , state , result )
200
+
201
def _make_history_item(
    self,
    model_output: CustomAgentOutput | None,
    state: BrowserState,
    result: list[ActionResult],
) -> None:
    """Create a history entry for one step and append it to ``self.history.history``.

    Args:
        model_output: Model output for the step, or ``None`` if there is none.
        state: Browser state whose url, title, tabs, screenshot and
            selector map are recorded.
        result: Action results stored on the history entry.
    """
    # Without a model output there is nothing to resolve, so a single
    # ``None`` placeholder is recorded instead.
    if model_output:
        interacted_elements = AgentHistory.get_interacted_element(
            model_output, state.selector_map
        )
    else:
        interacted_elements = [None]

    state_history = BrowserStateHistory(
        url=state.url,
        title=state.title,
        tabs=state.tabs,
        interacted_element=interacted_elements,
        screenshot=state.screenshot,
    )

    history_item = AgentHistory(model_output=model_output, result=result, state=state_history)

    self.history.history.append(history_item)
229
+
230
async def run(self, max_steps: int = 100) -> AgentHistoryList:
    """Run the task for at most ``max_steps`` steps and return ``self.history``.

    A run-start telemetry event is captured first. The loop stops early when
    too many failures have occurred or when the history reports the task as
    done (after optional output validation). An end telemetry event is
    always captured, and the browser context / browser are closed unless
    they are flagged as injected.

    Args:
        max_steps: Upper bound on the number of steps executed.

    Returns:
        The agent's history object.
    """
    try:
        logger.info(f'🚀 Starting task: {self.task} ')

        start_event = AgentRunTelemetryEvent(
            agent_id=self.agent_id,
            task=self.task,
        )
        self.telemetry.capture(start_event)

        # One step record is shared by, and mutated across, all steps.
        step_info = CustomAgentStepInfo(
            task=self.task,
            add_infos=self.add_infos,
            step_number=1,
            max_steps=max_steps,
            memory='',
            task_progress='',
        )

        last_index = max_steps - 1
        for step_index in range(max_steps):
            if self._too_many_failures():
                break

            await self.step(step_info)

            if not self.history.is_done():
                continue

            # if last step, we dont need to validate
            must_validate = self.validate_output and step_index < last_index
            if must_validate and not await self._validate_output():
                continue

            logger.info('✅ Task completed successfully')
            break
        else:
            # Reached only when the loop ran out of steps without a break.
            logger.info('❌ Failed to complete task in maximum steps')

        return self.history

    finally:
        end_event = AgentEndTelemetryEvent(
            agent_id=self.agent_id,
            task=self.task,
            success=self.history.is_done(),
            steps=len(self.history.history),
        )
        self.telemetry.capture(end_event)

        # Anything flagged as injected is left open.
        if not self.injected_browser_context:
            await self.browser_context.close()

        if not self.injected_browser and self.browser:
            await self.browser.close()
0 commit comments