Skip to content

Commit 664dce7

Browse files
committed
add deepseek-r1 ollama
1 parent 6ceaf8d commit 664dce7

File tree

8 files changed

+94
-30
lines changed

8 files changed

+94
-30
lines changed

.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ AZURE_OPENAI_API_KEY=
1111
DEEPSEEK_ENDPOINT=https://api.deepseek.com
1212
DEEPSEEK_API_KEY=
1313

14+
OLLAMA_ENDPOINT=http://localhost:11434
15+
1416
# Set to false to disable anonymized telemetry
1517
ANONYMIZED_TELEMETRY=true
1618

src/agent/custom_agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
register_done_callback=register_done_callback,
9999
tool_calling_method=tool_calling_method
100100
)
101-
if self.model_name == "deepseek-reasoner":
101+
if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"):
102102
# deepseek-reasoner does not support function calling
103103
self.use_deepseek_r1 = True
104104
# deepseek-reasoner only support 64000 context
@@ -191,6 +191,7 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu
191191
parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
192192
parsed: AgentOutput = self.AgentOutput(**parsed_json)
193193
if parsed is None:
194+
logger.debug(ai_message.content)
194195
raise ValueError(f'Could not parse response.')
195196
else:
196197
ai_message = self.llm.invoke(input_messages)
@@ -201,6 +202,7 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu
201202
parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
202203
parsed: AgentOutput = self.AgentOutput(**parsed_json)
203204
if parsed is None:
205+
logger.debug(ai_message.content)
204206
raise ValueError(f'Could not parse response.')
205207

206208
# cut the number of actions to max_actions_per_step
@@ -229,6 +231,9 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
229231
self.update_step_info(model_output, step_info)
230232
logger.info(f"🧠 All Memory: \n{step_info.memory}")
231233
self._save_conversation(input_messages, model_output)
234+
# should we remove last state message? at least, deepseek-reasoner cannot remove
235+
if self.model_name != "deepseek-reasoner":
236+
self.message_manager._remove_last_state_message()
232237
except Exception as e:
233238
# model call failed, remove last state message from history
234239
self.message_manager._remove_last_state_message()
@@ -253,7 +258,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
253258
self.consecutive_failures = 0
254259

255260
except Exception as e:
256-
result = self._handle_step_error(e)
261+
result = await self._handle_step_error(e)
257262
self._last_result = result
258263

259264
finally:

src/agent/custom_prompts.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,7 @@ def important_rules(self) -> str:
2626
"summary": "Please generate a brief natural language description for the operation in next actions based on your Thought."
2727
},
2828
"action": [
29-
{
30-
"action_name": {
31-
// action-specific parameters
32-
}
33-
},
34-
// ... more actions in sequence
29+
* actions in sequences, please refer to **Common action sequences**. Each output action MUST be formatted as: \{action_name\: action_params\}*
3530
]
3631
}
3732
@@ -44,7 +39,6 @@ def important_rules(self) -> str:
4439
{"click_element": {"index": 3}}
4540
]
4641
- Navigation and extraction: [
47-
{"open_new_tab": {}},
4842
{"go_to_url": {"url": "https://example.com"}},
4943
{"extract_page_content": {}}
5044
]
@@ -127,7 +121,7 @@ def get_system_message(self) -> SystemMessage:
127121
AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to:
128122
1. Analyze the provided webpage elements and structure
129123
2. Plan a sequence of actions to accomplish the given task
130-
3. Respond with valid JSON containing your action sequence and state assessment
124+
3. Your final result MUST be a valid JSON as the **RESPONSE FORMAT** described, containing your action sequence and state assessment. No extra content is needed to explain.
131125
132126
Current date and time: {time_str}
133127
@@ -200,15 +194,16 @@ def get_user_message(self) -> HumanMessage:
200194
"""
201195

202196
if self.result:
197+
203198
for i, result in enumerate(self.result):
204199
if result.include_in_memory:
205200
if result.extracted_content:
206-
state_description += f"\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}"
201+
state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}"
207202
if result.error:
208203
# only use last 300 characters of error
209204
error = result.error[-self.max_error_length:]
210205
state_description += (
211-
f"\nError of action {i + 1}/{len(self.result)}: ...{error}"
206+
f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}"
212207
)
213208

214209
if self.state.screenshot:

src/utils/llm.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
LLMResult,
2626
RunInfo,
2727
)
28+
from langchain_ollama import ChatOllama
2829
from langchain_core.output_parsers.base import OutputParserLike
2930
from langchain_core.runnables import Runnable, RunnableConfig
3031
from langchain_core.tools import BaseTool
@@ -98,4 +99,38 @@ def invoke(
9899

99100
reasoning_content = response.choices[0].message.reasoning_content
100101
content = response.choices[0].message.content
102+
return AIMessage(content=content, reasoning_content=reasoning_content)
103+
104+
class DeepSeekR1ChatOllama(ChatOllama):
105+
106+
async def ainvoke(
107+
self,
108+
input: LanguageModelInput,
109+
config: Optional[RunnableConfig] = None,
110+
*,
111+
stop: Optional[list[str]] = None,
112+
**kwargs: Any,
113+
) -> AIMessage:
114+
org_ai_message = await super().ainvoke(input=input)
115+
org_content = org_ai_message.content
116+
reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
117+
content = org_content.split("</think>")[1]
118+
if "**JSON Response:**" in content:
119+
content = content.split("**JSON Response:**")[-1]
120+
return AIMessage(content=content, reasoning_content=reasoning_content)
121+
122+
def invoke(
123+
self,
124+
input: LanguageModelInput,
125+
config: Optional[RunnableConfig] = None,
126+
*,
127+
stop: Optional[list[str]] = None,
128+
**kwargs: Any,
129+
) -> AIMessage:
130+
org_ai_message = super().invoke(input=input)
131+
org_content = org_ai_message.content
132+
reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
133+
content = org_content.split("</think>")[1]
134+
if "**JSON Response:**" in content:
135+
content = content.split("**JSON Response:**")[-1]
101136
return AIMessage(content=content, reasoning_content=reasoning_content)

src/utils/utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from langchain_openai import AzureChatOpenAI, ChatOpenAI
1111
import gradio as gr
1212

13-
from .llm import DeepSeekR1ChatOpenAI
13+
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
1414

1515
def get_llm_model(provider: str, **kwargs):
1616
"""
@@ -89,12 +89,25 @@ def get_llm_model(provider: str, **kwargs):
8989
google_api_key=api_key,
9090
)
9191
elif provider == "ollama":
92-
return ChatOllama(
93-
model=kwargs.get("model_name", "qwen2.5:7b"),
94-
temperature=kwargs.get("temperature", 0.0),
95-
num_ctx=kwargs.get("num_ctx", 32000),
96-
base_url=kwargs.get("base_url", "http://localhost:11434"),
97-
)
92+
if not kwargs.get("base_url", ""):
93+
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
94+
else:
95+
base_url = kwargs.get("base_url")
96+
97+
if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
98+
return DeepSeekR1ChatOllama(
99+
model=kwargs.get("model_name", "deepseek-r1:7b"),
100+
temperature=kwargs.get("temperature", 0.0),
101+
num_ctx=kwargs.get("num_ctx", 32000),
102+
base_url=kwargs.get("base_url", base_url),
103+
)
104+
else:
105+
return ChatOllama(
106+
model=kwargs.get("model_name", "qwen2.5:7b"),
107+
temperature=kwargs.get("temperature", 0.0),
108+
num_ctx=kwargs.get("num_ctx", 32000),
109+
base_url=kwargs.get("base_url", base_url),
110+
)
98111
elif provider == "azure_openai":
99112
if not kwargs.get("base_url", ""):
100113
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
@@ -120,7 +133,7 @@ def get_llm_model(provider: str, **kwargs):
120133
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
121134
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
122135
"gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
123-
"ollama": ["qwen2.5:7b", "llama2:7b"],
136+
"ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
124137
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
125138
}
126139

tests/test_browser_use.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -257,22 +257,26 @@ async def test_browser_use_custom_v2():
257257
# temperature=0.8
258258
# )
259259

260-
llm = utils.get_llm_model(
261-
provider="deepseek",
262-
model_name="deepseek-chat",
263-
temperature=0.8
264-
)
260+
# llm = utils.get_llm_model(
261+
# provider="deepseek",
262+
# model_name="deepseek-chat",
263+
# temperature=0.8
264+
# )
265265

266266
# llm = utils.get_llm_model(
267267
# provider="ollama", model_name="qwen2.5:7b", temperature=0.5
268268
# )
269+
270+
# llm = utils.get_llm_model(
271+
# provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
272+
# )
269273

270274
controller = CustomController()
271275
use_own_browser = False
272276
disable_security = True
273277
use_vision = False # Set to False when using DeepSeek
274278

275-
max_actions_per_step = 1
279+
max_actions_per_step = 10
276280
playwright = None
277281
browser = None
278282
browser_context = None
@@ -303,7 +307,7 @@ async def test_browser_use_custom_v2():
303307
)
304308
)
305309
agent = CustomAgent(
306-
task="give me stock price of Nvidia and tesla",
310+
task="go to google.com and type 'Nvidia' click search and give me the first url",
307311
add_infos="", # some hints for llm to complete the task
308312
llm=llm,
309313
browser=browser,

tests/test_llm_api.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,21 @@ def test_ollama_model():
142142
llm = ChatOllama(model="qwen2.5:7b")
143143
ai_msg = llm.invoke("Sing a ballad of LangChain.")
144144
print(ai_msg.content)
145+
146+
def test_deepseek_r1_ollama_model():
147+
from src.utils.llm import DeepSeekR1ChatOllama
148+
149+
llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
150+
ai_msg = llm.invoke("how many r in strawberry?")
151+
print(ai_msg.content)
152+
pdb.set_trace()
145153

146154

147155
if __name__ == '__main__':
148156
# test_openai_model()
149157
# test_gemini_model()
150158
# test_azure_openai_model()
151-
test_deepseek_model()
159+
# test_deepseek_model()
152160
# test_ollama_model()
153-
# test_deepseek_r1_model()
161+
test_deepseek_r1_model()
162+
# test_deepseek_r1_ollama_model()

webui.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,8 @@ def create_ui(config, theme_name="Ocean"):
658658
interactive=True,
659659
allow_custom_value=True, # Allow users to input custom model names
660660
choices=["auto", "json_schema", "function_calling"],
661-
info="Tool Calls Funtion Name"
661+
info="Tool Calls Function Name",
662+
visible=False
662663
)
663664

664665
with gr.TabItem("🔧 LLM Configuration", id=2):

0 commit comments

Comments
 (0)