Skip to content

Commit 863e865

Browse files
committed
add custom agent
1 parent 6c07ec2 commit 863e865

File tree

5 files changed

+461
-20
lines changed

5 files changed

+461
-20
lines changed

src/agent/custom_agent.py

Lines changed: 176 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55
# @FileName: custom_agent.py
66

77
import asyncio
8+
import base64
9+
import io
810
import json
911
import logging
1012
import os
13+
import pdb
14+
import textwrap
1115
import time
1216
import uuid
17+
from io import BytesIO
1318
from pathlib import Path
1419
from typing import Any, Optional, Type, TypeVar
1520

@@ -20,10 +25,12 @@
2025
SystemMessage,
2126
)
2227
from openai import RateLimitError
28+
from PIL import Image, ImageDraw, ImageFont
2329
from pydantic import BaseModel, ValidationError
2430

2531
from browser_use.agent.message_manager.service import MessageManager
2632
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
33+
from browser_use.agent.service import Agent
2734
from browser_use.agent.views import (
2835
ActionResult,
2936
AgentError,
@@ -32,21 +39,76 @@
3239
AgentOutput,
3340
AgentStepInfo,
3441
)
42+
from browser_use.browser.browser import Browser
43+
from browser_use.browser.context import BrowserContext
44+
from browser_use.browser.views import BrowserState, BrowserStateHistory
45+
from browser_use.controller.registry.views import ActionModel
46+
from browser_use.controller.service import Controller
47+
from browser_use.dom.history_tree_processor.service import (
48+
DOMHistoryElement,
49+
HistoryTreeProcessor,
50+
)
51+
from browser_use.telemetry.service import ProductTelemetry
3552
from browser_use.telemetry.views import (
3653
AgentEndTelemetryEvent,
3754
AgentRunTelemetryEvent,
3855
AgentStepErrorTelemetryEvent,
3956
)
40-
from browser_use.agent.service import Agent
4157
from browser_use.utils import time_execution_async
4258

43-
from .custom_views import CustomAgentOutput
59+
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
60+
from .custom_massage_manager import CustomMassageManager
4461

4562
logger = logging.getLogger(__name__)
4663

4764

4865
class CustomAgent(Agent):
4966

67+
def __init__(
    self,
    task: str,
    llm: BaseChatModel,
    add_infos: str = '',
    browser: Browser | None = None,
    browser_context: BrowserContext | None = None,
    controller: Controller | None = None,
    use_vision: bool = True,
    save_conversation_path: Optional[str] = None,
    max_failures: int = 5,
    retry_delay: int = 10,
    system_prompt_class: Type[SystemPrompt] = SystemPrompt,
    max_input_tokens: int = 128000,
    validate_output: bool = False,
    include_attributes: list[str] | None = None,
    max_error_length: int = 400,
    max_actions_per_step: int = 10,
):
    """Initialise the agent and attach a ``CustomMassageManager``.

    Every argument except ``add_infos`` is forwarded positionally to
    ``Agent.__init__``; see the base class for their meaning.

    Args:
        add_infos: Extra text kept on ``self.add_infos``.
        controller: Controller passed to the base class. ``None`` (the
            default) creates a fresh ``Controller`` for this agent.
        include_attributes: Attribute names passed to the base class and
            to the message manager. ``None`` (the default) selects the
            standard list defined in the body.
    """
    # Defaults are built per call: a default written in the signature is
    # evaluated once and the same object is then shared by every agent.
    if controller is None:
        controller = Controller()
    if include_attributes is None:
        include_attributes = [
            'title',
            'type',
            'name',
            'role',
            'tabindex',
            'aria-label',
            'placeholder',
            'value',
            'alt',
            'aria-expanded',
        ]

    super().__init__(task, llm, browser, browser_context, controller, use_vision, save_conversation_path,
                     max_failures, retry_delay, system_prompt_class, max_input_tokens, validate_output,
                     include_attributes, max_error_length, max_actions_per_step)
    self.add_infos = add_infos
    # Assigned after the base initialiser has run, so this manager takes
    # precedence over whatever the base class may have set up.
    self.message_manager = CustomMassageManager(
        llm=self.llm,
        task=self.task,
        action_descriptions=self.controller.registry.get_prompt_description(),
        system_prompt_class=self.system_prompt_class,
        max_input_tokens=self.max_input_tokens,
        include_attributes=self.include_attributes,
        max_error_length=self.max_error_length,
        max_actions_per_step=self.max_actions_per_step,
    )
111+
50112
def _setup_action_models(self) -> None:
51113
"""Setup dynamic action models from controller's registry"""
52114
# Get the dynamic action model from controller's registry
@@ -56,23 +118,42 @@ def _setup_action_models(self) -> None:
56118

57119
def _log_response(self, response: CustomAgentOutput) -> None:
    """Log the model's response.

    Emits one INFO line each for the previous-action evaluation, the
    remembered contents, the task progress, the thought and the summary
    held in ``response.current_state``, followed by one line per entry in
    ``response.action``.
    """
    state = response.current_state
    evaluation = state.prev_action_evaluation

    # Distinct markers per outcome; these match the ones used for the
    # final task-status log lines in ``run``.
    if 'Success' in evaluation:
        emoji = '✅'
    elif 'Failed' in evaluation:
        emoji = '❌'
    else:
        emoji = '🤷'

    logger.info(f'{emoji} Eval: {evaluation}')
    logger.info(f'🧠 Memory: {state.import_contents}')
    logger.info(f'⏳ Task Progress: {state.completed_contents}')
    logger.info(f'🤔 Thought: {state.thought}')
    logger.info(f'🎯 Summary: {state.summary}')

    total = len(response.action)
    for position, action in enumerate(response.action, start=1):
        logger.info(
            f'🛠️ Action {position}/{total}: {action.model_dump_json(exclude_unset=True)}'
        )
73137

138+
def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None):
    """Fold the latest model output into the running step info.

    Does nothing when ``step_info`` is ``None``. Otherwise it advances
    ``step_info.step_number`` by one, appends the model's
    ``import_contents`` (plus a newline) to ``step_info.memory`` and
    overwrites ``step_info.task_progress`` with ``completed_contents``.
    A value is ignored when it is empty or contains the text ``'None'``;
    ``import_contents`` is also ignored when it already occurs in the memory.
    """
    if step_info is not None:
        step_info.step_number += 1
        state = model_output.current_state

        note = state.import_contents
        skip_note = not note or 'None' in note or note in step_info.memory
        if not skip_note:
            step_info.memory = step_info.memory + note + '\n'

        progress = state.completed_contents
        if progress and 'None' not in progress:
            step_info.task_progress = progress
153+
154+
74155
@time_execution_async('--step')
75-
async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
156+
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
76157
"""Execute one step of the task"""
77158
logger.info(f'\n📍 Step {self.n_steps}')
78159
state = None
@@ -84,6 +165,7 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
84165
self.message_manager.add_state_message(state, self._last_result, step_info)
85166
input_messages = self.message_manager.get_messages()
86167
model_output = await self.get_next_action(input_messages)
168+
self.update_step_info(model_output, step_info)
87169
self._save_conversation(input_messages, model_output)
88170
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
89171
self.message_manager.add_model_output(model_output)
@@ -115,3 +197,87 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
115197
)
116198
if state:
117199
self._make_history_item(model_output, state, result)
200+
201+
def _make_history_item(
    self,
    model_output: CustomAgentOutput | None,
    state: BrowserState,
    result: list[ActionResult],
) -> None:
    """Create a history item for one step and append it to ``self.history.history``.

    Args:
        model_output: Output of the model for this step, or ``None``; when
            falsy the interacted-element list is ``[None]``.
        state: Browser state whose url, title, tabs and screenshot are
            copied into the stored ``BrowserStateHistory``.
        result: Action results stored unchanged on the history item.
    """
    # The interacted elements can only be resolved when the model produced output.
    if model_output:
        interacted_elements = AgentHistory.get_interacted_element(
            model_output, state.selector_map
        )
    else:
        interacted_elements = [None]

    state_history = BrowserStateHistory(
        url=state.url,
        title=state.title,
        tabs=state.tabs,
        interacted_element=interacted_elements,
        screenshot=state.screenshot,
    )

    history_item = AgentHistory(model_output=model_output, result=result, state=state_history)

    self.history.history.append(history_item)
229+
230+
async def run(self, max_steps: int = 100) -> AgentHistoryList:
    """Run the task for at most ``max_steps`` steps and return the history.

    The loop stops early when ``_too_many_failures()`` reports true or when
    the history is done. With ``validate_output`` enabled, a done result on
    any step but the last is validated and, if rejected, the loop carries on.
    On exit, successful or not, an end-of-run telemetry event is captured.
    The browser context is closed unless ``self.injected_browser_context``
    is set, and the browser is closed unless ``self.injected_browser`` is set.
    """
    try:
        logger.info(f'🚀 Starting task: {self.task}')

        run_started = AgentRunTelemetryEvent(
            agent_id=self.agent_id,
            task=self.task,
        )
        self.telemetry.capture(run_started)

        # One mutable record shared by every step of this run.
        step_info = CustomAgentStepInfo(
            task=self.task,
            add_infos=self.add_infos,
            step_number=1,
            max_steps=max_steps,
            memory='',
            task_progress='',
        )

        last_index = max_steps - 1
        for step_index in range(max_steps):
            if self._too_many_failures():
                break

            await self.step(step_info)

            if not self.history.is_done():
                continue

            # if last step, we dont need to validate
            must_validate = self.validate_output and step_index < last_index
            if must_validate and not await self._validate_output():
                continue

            logger.info('✅ Task completed successfully')
            break
        else:
            # Reached only when the loop ran out of steps without a break.
            logger.info('❌ Failed to complete task in maximum steps')

        return self.history

    finally:
        run_ended = AgentEndTelemetryEvent(
            agent_id=self.agent_id,
            task=self.task,
            success=self.history.is_done(),
            steps=len(self.history.history),
        )
        self.telemetry.capture(run_ended)

        if not self.injected_browser_context:
            await self.browser_context.close()

        if not self.injected_browser and self.browser:
            await self.browser.close()
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2025/1/2
3+
# @Author : wenshao
4+
# @ProjectName: browser-use-webui
5+
# @FileName: custom_massage_manager.py
6+
7+
from __future__ import annotations
8+
9+
import logging
10+
from datetime import datetime
11+
from typing import List, Optional, Type
12+
13+
from langchain_anthropic import ChatAnthropic
14+
from langchain_core.language_models import BaseChatModel
15+
from langchain_core.messages import (
16+
AIMessage,
17+
BaseMessage,
18+
HumanMessage,
19+
)
20+
from langchain_openai import ChatOpenAI
21+
22+
from browser_use.agent.message_manager.views import MessageHistory, MessageMetadata
23+
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
24+
from browser_use.agent.views import ActionResult, AgentOutput, AgentStepInfo
25+
from browser_use.browser.views import BrowserState
26+
from browser_use.agent.message_manager.service import MessageManager
27+
28+
from .custom_prompts import CustomAgentMessagePrompt
29+
30+
logger = logging.getLogger(__name__)
31+
32+
33+
class CustomMassageManager(MessageManager):
    """``MessageManager`` variant that builds state messages with ``CustomAgentMessagePrompt``.

    After the base initialiser runs, the history is replaced by a fresh
    ``MessageHistory`` that holds only the system prompt.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        task: str,
        action_descriptions: str,
        system_prompt_class: Type[SystemPrompt],
        max_input_tokens: int = 128000,
        estimated_tokens_per_character: int = 3,
        image_tokens: int = 800,
        include_attributes: Optional[list[str]] = None,
        max_error_length: int = 400,
        max_actions_per_step: int = 10,
    ):
        """Initialise the base manager, then reset the history to the system prompt.

        All arguments are forwarded positionally to ``MessageManager.__init__``.
        ``include_attributes`` defaults to ``None``, which is treated as an
        empty list.
        """
        # A list literal in the signature would be one object shared by
        # every instance, so the empty default is created per call.
        if include_attributes is None:
            include_attributes = []

        super().__init__(llm, task, action_descriptions, system_prompt_class, max_input_tokens,
                         estimated_tokens_per_character, image_tokens, include_attributes, max_error_length,
                         max_actions_per_step)

        # Move Task info to state_message: start again from an empty
        # history that contains only the system prompt.
        self.history = MessageHistory()
        self._add_message_with_tokens(self.system_prompt)

    def add_state_message(
        self,
        state: BrowserState,
        result: Optional[List[ActionResult]] = None,
        step_info: Optional[AgentStepInfo] = None,
    ) -> None:
        """Add browser state as human message.

        Results flagged ``include_in_memory`` are added to the history
        first, as separate human messages (extracted content, then the
        error text). If any such result exists, no results are passed to
        the state message.
        """
        # if keep in memory, add directly to history and add state without result
        if result:
            for r in result:
                if r.include_in_memory:
                    if r.extracted_content:
                        msg = HumanMessage(content=str(r.extracted_content))
                        self._add_message_with_tokens(msg)
                    if r.error:
                        # Keep only the trailing max_error_length characters of the error.
                        msg = HumanMessage(content=str(r.error)[-self.max_error_length:])
                        self._add_message_with_tokens(msg)
                    result = None  # if result in history, we dont want to add it again

        # otherwise add state message and result to next message (which will not stay in memory)
        state_message = CustomAgentMessagePrompt(
            state,
            result,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
            step_info=step_info,
        ).get_user_message()
        self._add_message_with_tokens(state_message)

0 commit comments

Comments
 (0)