Skip to content

Commit 3740b93

Browse files
Update custom_agent.py
1 parent 83b08e4 commit 3740b93

File tree

1 file changed

+103
-86
lines changed

1 file changed

+103
-86
lines changed

src/agent/custom_agent.py

Lines changed: 103 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,71 +4,45 @@
44
# @ProjectName: browser-use-webui
55
# @FileName: custom_agent.py
66

7-
import asyncio
8-
import base64
9-
import io
107
import json
118
import logging
12-
import os
139
import pdb
14-
import textwrap
15-
import time
16-
import uuid
17-
from io import BytesIO
18-
from pathlib import Path
19-
from typing import Any, Optional, Type, TypeVar
20-
21-
from dotenv import load_dotenv
22-
from langchain_core.language_models.chat_models import BaseChatModel
23-
from langchain_core.messages import (
24-
BaseMessage,
25-
SystemMessage,
26-
)
27-
from openai import RateLimitError
28-
from PIL import Image, ImageDraw, ImageFont
29-
from pydantic import BaseModel, ValidationError
10+
import traceback
11+
from typing import Optional, Type
3012

31-
from browser_use.agent.message_manager.service import MessageManager
32-
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
13+
from browser_use.agent.prompts import SystemPrompt
3314
from browser_use.agent.service import Agent
3415
from browser_use.agent.views import (
3516
ActionResult,
36-
AgentError,
37-
AgentHistory,
3817
AgentHistoryList,
3918
AgentOutput,
40-
AgentStepInfo,
4119
)
4220
from browser_use.browser.browser import Browser
4321
from browser_use.browser.context import BrowserContext
44-
from browser_use.browser.views import BrowserState, BrowserStateHistory
45-
from browser_use.controller.registry.views import ActionModel
4622
from browser_use.controller.service import Controller
47-
from browser_use.dom.history_tree_processor.service import (
48-
DOMHistoryElement,
49-
HistoryTreeProcessor,
50-
)
51-
from browser_use.telemetry.service import ProductTelemetry
5223
from browser_use.telemetry.views import (
5324
AgentEndTelemetryEvent,
5425
AgentRunTelemetryEvent,
5526
AgentStepErrorTelemetryEvent,
5627
)
5728
from browser_use.utils import time_execution_async
29+
from langchain_core.language_models.chat_models import BaseChatModel
30+
from langchain_core.messages import (
31+
BaseMessage,
32+
)
5833

59-
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
6034
from .custom_massage_manager import CustomMassageManager
35+
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
6136

6237
logger = logging.getLogger(__name__)
6338

6439

6540
class CustomAgent(Agent):
66-
6741
def __init__(
6842
self,
6943
task: str,
7044
llm: BaseChatModel,
71-
add_infos: str = '',
45+
add_infos: str = "",
7246
browser: Browser | None = None,
7347
browser_context: BrowserContext | None = None,
7448
controller: Controller = Controller(),
@@ -80,23 +54,39 @@ def __init__(
8054
max_input_tokens: int = 128000,
8155
validate_output: bool = False,
8256
include_attributes: list[str] = [
83-
'title',
84-
'type',
85-
'name',
86-
'role',
87-
'tabindex',
88-
'aria-label',
89-
'placeholder',
90-
'value',
91-
'alt',
92-
'aria-expanded',
57+
"title",
58+
"type",
59+
"name",
60+
"role",
61+
"tabindex",
62+
"aria-label",
63+
"placeholder",
64+
"value",
65+
"alt",
66+
"aria-expanded",
9367
],
9468
max_error_length: int = 400,
9569
max_actions_per_step: int = 10,
70+
tool_call_in_content: bool = True,
9671
):
97-
super().__init__(task, llm, browser, browser_context, controller, use_vision, save_conversation_path,
98-
max_failures, retry_delay, system_prompt_class, max_input_tokens, validate_output,
99-
include_attributes, max_error_length, max_actions_per_step)
72+
super().__init__(
73+
task=task,
74+
llm=llm,
75+
browser=browser,
76+
browser_context=browser_context,
77+
controller=controller,
78+
use_vision=use_vision,
79+
save_conversation_path=save_conversation_path,
80+
max_failures=max_failures,
81+
retry_delay=retry_delay,
82+
system_prompt_class=system_prompt_class,
83+
max_input_tokens=max_input_tokens,
84+
validate_output=validate_output,
85+
include_attributes=include_attributes,
86+
max_error_length=max_error_length,
87+
max_actions_per_step=max_actions_per_step,
88+
tool_call_in_content=tool_call_in_content,
89+
)
10090
self.add_infos = add_infos
10191
self.message_manager = CustomMassageManager(
10292
llm=self.llm,
@@ -107,6 +97,7 @@ def __init__(
10797
include_attributes=self.include_attributes,
10898
max_error_length=self.max_error_length,
10999
max_actions_per_step=self.max_actions_per_step,
100+
tool_call_in_content=tool_call_in_content,
110101
)
111102

112103
def _setup_action_models(self) -> None:
@@ -118,24 +109,26 @@ def _setup_action_models(self) -> None:
118109

119110
def _log_response(self, response: CustomAgentOutput) -> None:
    """Write a human-readable summary of the model's response to the log.

    Logs, in order: the evaluation of the previous action (prefixed with a
    status emoji), the new memory, the task progress, the thought, the
    summary, and then one line per proposed action.

    Args:
        response: The parsed model output whose ``current_state`` and
            ``action`` list are logged.
    """
    state = response.current_state
    evaluation = state.prev_action_evaluation

    # The emoji is derived from the evaluation text itself; "Success" is
    # checked first, so it wins if both markers appear.
    if "Success" in evaluation:
        emoji = "✅"
    elif "Failed" in evaluation:
        emoji = "❌"
    else:
        emoji = "🤷"

    logger.info(f"{emoji} Eval: {evaluation}")
    logger.info(f"🧠 New Memory: {state.important_contents}")
    logger.info(f"⏳ Task Progress: {state.completed_contents}")
    logger.info(f"🤔 Thought: {state.thought}")
    logger.info(f"🎯 Summary: {state.summary}")

    total = len(response.action)
    for position, action in enumerate(response.action, start=1):
        logger.info(
            f"🛠️ Action {position}/{total}: {action.model_dump_json(exclude_unset=True)}"
        )
137128

138-
def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None):
129+
def update_step_info(
130+
self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None
131+
):
139132
"""
140133
update step info
141134
"""
@@ -144,31 +137,54 @@ def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAge
144137

145138
step_info.step_number += 1
146139
important_contents = model_output.current_state.important_contents
147-
if important_contents and 'None' not in important_contents and important_contents not in step_info.memory:
148-
step_info.memory += important_contents + '\n'
140+
if (
141+
important_contents
142+
and "None" not in important_contents
143+
and important_contents not in step_info.memory
144+
):
145+
step_info.memory += important_contents + "\n"
149146

150147
completed_contents = model_output.current_state.completed_contents
151-
if completed_contents and 'None' not in completed_contents:
148+
if completed_contents and "None" not in completed_contents:
152149
step_info.task_progress = completed_contents
153150

154-
@time_execution_async('--get_next_action')
151+
@time_execution_async("--get_next_action")
155152
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
156153
"""Get next action from LLM based on current state"""
154+
try:
155+
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
156+
response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
157157

158-
ret = self.llm.invoke(input_messages)
159-
parsed_json = json.loads(ret.content.replace('```json', '').replace("```", ""))
160-
parsed: AgentOutput = self.AgentOutput(**parsed_json)
161-
# cut the number of actions to max_actions_per_step
162-
parsed.action = parsed.action[: self.max_actions_per_step]
163-
self._log_response(parsed)
164-
self.n_steps += 1
158+
parsed: AgentOutput = response['parsed']
159+
# cut the number of actions to max_actions_per_step
160+
parsed.action = parsed.action[: self.max_actions_per_step]
161+
self._log_response(parsed)
162+
self.n_steps += 1
165163

166-
return parsed
164+
return parsed
165+
except Exception as e:
166+
# If something goes wrong, try to invoke the LLM again without structured output,
167+
# and Manually parse the response. Temporarily solution for DeepSeek
168+
ret = self.llm.invoke(input_messages)
169+
if isinstance(ret.content, list):
170+
parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
171+
else:
172+
parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
173+
parsed: AgentOutput = self.AgentOutput(**parsed_json)
174+
if parsed is None:
175+
raise ValueError(f'Could not parse response.')
176+
177+
# cut the number of actions to max_actions_per_step
178+
parsed.action = parsed.action[: self.max_actions_per_step]
179+
self._log_response(parsed)
180+
self.n_steps += 1
167181

168-
@time_execution_async('--step')
182+
return parsed
183+
184+
@time_execution_async("--step")
169185
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
170186
"""Execute one step of the task"""
171-
logger.info(f'\n📍 Step {self.n_steps}')
187+
logger.info(f"\n📍 Step {self.n_steps}")
172188
state = None
173189
model_output = None
174190
result: list[ActionResult] = []
@@ -179,7 +195,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
179195
input_messages = self.message_manager.get_messages()
180196
model_output = await self.get_next_action(input_messages)
181197
self.update_step_info(model_output, step_info)
182-
logger.info(f'🧠 All Memory: {step_info.memory}')
198+
logger.info(f"🧠 All Memory: {step_info.memory}")
183199
self._save_conversation(input_messages, model_output)
184200
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
185201
self.message_manager.add_model_output(model_output)
@@ -190,7 +206,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
190206
self._last_result = result
191207

192208
if len(result) > 0 and result[-1].is_done:
193-
logger.info(f'📄 Result: {result[-1].extracted_content}')
209+
logger.info(f"📄 Result: {result[-1].extracted_content}")
194210

195211
self.consecutive_failures = 0
196212

@@ -215,7 +231,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
215231
async def run(self, max_steps: int = 100) -> AgentHistoryList:
216232
"""Execute the task with maximum number of steps"""
217233
try:
218-
logger.info(f'🚀 Starting task: {self.task}')
234+
logger.info(f"🚀 Starting task: {self.task}")
219235

220236
self.telemetry.capture(
221237
AgentRunTelemetryEvent(
@@ -224,13 +240,14 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
224240
)
225241
)
226242

227-
step_info = CustomAgentStepInfo(task=self.task,
228-
add_infos=self.add_infos,
229-
step_number=1,
230-
max_steps=max_steps,
231-
memory='',
232-
task_progress=''
233-
)
243+
step_info = CustomAgentStepInfo(
244+
task=self.task,
245+
add_infos=self.add_infos,
246+
step_number=1,
247+
max_steps=max_steps,
248+
memory="",
249+
task_progress="",
250+
)
234251

235252
for step in range(max_steps):
236253
if self._too_many_failures():
@@ -245,10 +262,10 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
245262
if not await self._validate_output():
246263
continue
247264

248-
logger.info('✅ Task completed successfully')
265+
logger.info("✅ Task completed successfully")
249266
break
250267
else:
251-
logger.info('❌ Failed to complete task in maximum steps')
268+
logger.info("❌ Failed to complete task in maximum steps")
252269

253270
return self.history
254271

0 commit comments

Comments
 (0)