Commit 3efe706

Merge pull request #3 from danilaplee/ollama
feat: add ollama
2 parents 2905a53 + 7220a11

8 files changed: 112 additions, 19 deletions

.github/workflows/docker.yml

Lines changed: 3 additions & 3 deletions

@@ -37,7 +37,7 @@ jobs:
         id: meta
         uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
         with:
-          images: browseruse/browser-use
+          images: browseruser/browser-use

       - name: Build and push Docker image
         id: push
@@ -49,5 +49,5 @@ jobs:
           push: true
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
-          cache-from: type=registry,ref=browseruse/browser-use:buildcache
-          cache-to: type=registry,ref=browseruse/browser-use:buildcache,mode=max
+          cache-from: type=registry,ref=browseruser/browser-user:buildcache
+          cache-to: type=registry,ref=browseruser/browser-user:buildcache,mode=max

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
+ollama
 # C extensions
 *.so

Dockerfile

Lines changed: 3 additions & 1 deletion

@@ -141,6 +141,8 @@ RUN playwright install-deps

 RUN apt-get install xauth -y

+RUN pip install --no-cache-dir langchain-ollama
+
 # ensure correct permissions for /tmp/.X11-unix to prevent Xvfb from issuing warnings
 RUN mkdir -p /tmp/.X11-unix && chmod 1777 /tmp/.X11-unix

@@ -161,7 +163,7 @@ RUN chown -R appuser:appuser /app
 USER appuser

 # Expose port
-EXPOSE 8000
+EXPOSE 9000

 # Command to start the application
 CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
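
The new layer makes langchain-ollama importable inside the image. A minimal smoke test of that dependency, assuming an Ollama server is reachable from the container; the model tag and fallback URL are illustrative, not part of the commit:

# smoke_test.py - sketch only; requires a running Ollama server
import os

from langchain_ollama import ChatOllama  # provided by the new pip install layer

llm = ChatOllama(
    model="deepseek-r1:7b",  # example tag; any model pulled into Ollama works
    base_url=os.getenv("OLLAMA_HOST", "http://host.docker.internal:11434"),
)
print(llm.invoke("Reply with exactly: pong").content)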

api.py

Lines changed: 12 additions & 4 deletions

@@ -130,6 +130,7 @@ def calculate_max_tasks():

 # Function to execute a task
 async def execute_task(task_id: int, task: str, config: Dict[str, Any], db: Session):
+    result = None
     try:
         # Update status to running
         db_task = db.query(Task).filter(Task.id == task_id).first()
@@ -148,7 +149,11 @@ async def execute_task(task_id: int, task: str, config: Dict[str, Any], db: Session):
         if db_task:
             db_task.status = "completed"
             db_task.result = json.dumps({
-                "videopath": result.videopath
+                "videopath": result.videopath,
+                "result": result.result,
+                "task": result.task,
+                "steps_executed": result.steps_executed,
+                "success": result.success
             })
             db_task.completed_at = datetime.utcnow()
             db.commit()
@@ -158,9 +163,12 @@ async def execute_task(task_id: int, task: str, config: Dict[str, Any], db: Session):
         if db_task:
             db_task.status = "failed"
             db_task.error = str(e)
-            db_task.result = json.dumps({
-                "videopath": result.videopath
-            })
+            if result is not None:
+                db_task.result = json.dumps({
+                    "videopath": result.videopath
+                })
+            else:
+                db_task.result = json.dumps({})
             db_task.completed_at = datetime.utcnow()
             db.commit()
         await send_error_to_webhook(str(e), "execute_task", task_id)
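
Pre-initializing result = None is what makes the failure branch safe: the except handler references result, so before this change any exception raised before the agent returned would itself raise UnboundLocalError while building the failure payload. A minimal standalone sketch of the pattern, with run_agent as a hypothetical stand-in for the real agent call:

import json

def run_agent():
    raise RuntimeError("browser crashed before producing a result")  # simulated failure

result = None  # pre-initialize so the handler below can always reference it
try:
    result = run_agent()
    payload = json.dumps({"videopath": result.videopath})
except Exception as e:
    # Mirrors the diff: only serialize videopath when a result actually exists.
    if result is not None:
        payload = json.dumps({"videopath": result.videopath})
    else:
        payload = json.dumps({})

print(payload)  # -> {}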

browser.py

Lines changed: 5 additions & 1 deletion

@@ -57,6 +57,9 @@ async def execute_task(self, task: str, config: Dict[str, Any]) -> AgentResponse:
         # Initialize browser
         browser = Browser(config=browser_config)

+        tool_calling_method = "auto"
+        if "deepseek-r1" in llm_config.model_name:
+            tool_calling_method = "json_mode"
         # Initialize and run agent
         agent = Agent(
             task=task,
@@ -65,7 +68,8 @@ async def execute_task(self, task: str, config: Dict[str, Any]) -> AgentResponse:
             max_failures=config.get("max_failures", 5),
             use_vision=config.get("use_vision", True),
             memory_interval=config.get("memory_interval", 10),
-            planner_interval=config.get("planner_interval", 1)
+            planner_interval=config.get("planner_interval", 1),
+            tool_calling_method=tool_calling_method
         )

         result = await agent.run(max_steps=config.get("max_steps", 5))
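
browser.py and server.py apply the same switch: when the configured model is a DeepSeek-R1 variant, the agent parses structured output from plain JSON instead of relying on native tool calling (an assumption about R1's capabilities over Ollama, consistent with the DeepSeekR1ChatOllama wrapper added in settings.py). The rule, factored out as a sketch:

def select_tool_calling_method(model_name: str) -> str:
    # "json_mode" makes the agent read actions from raw JSON output,
    # sidestepping native function calling for R1-style reasoning models.
    return "json_mode" if "deepseek-r1" in model_name else "auto"

assert select_tool_calling_method("deepseek-r1:14b") == "json_mode"
assert select_tool_calling_method("gpt-4o") == "auto"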

compose.yml

Lines changed: 4 additions & 1 deletion

@@ -17,8 +17,10 @@ services:
   browser:
     depends_on:
       - db
-    image: browseruse/browser-use:pr-2
+    image: browseruse/browser-use:pr-3
     environment:
+      # - OLLAMA_HOST=http://host.docker.internal:11434
+      - OLLAMA_HOST=${OLLAMA_HOST}
       - ERROR_WEBHOOK_URL=http://localhost:3000
       - NOTIFY_WEBHOOK_URL=http://localhost:3000
       - METRICS_WEBHOOK_URL=http://localhost:3000
@@ -35,6 +37,7 @@ services:
      - CHROME_PERSISTENT_SESSION=true
      - RESOLUTION_WIDTH=1920
      - RESOLUTION_HEIGHT=1080
+      - APP_PORT=9000
     ports:
       - "9000:8000"
     volumes:
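
OLLAMA_HOST is injected from the host environment, and the commented-out line documents the common case of an Ollama daemon running on the Docker host. A sketch of how the containerized app resolves it; the fallback URL mirrors the commented-out compose value and is an assumption, since settings.py reads the variable without a default:

import os

# settings.py calls os.getenv("OLLAMA_HOST") with no default; the fallback
# below only illustrates the host-networking value from the compose comment.
ollama_host = os.getenv("OLLAMA_HOST", "http://host.docker.internal:11434")
print(f"Ollama endpoint: {ollama_host}")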

server.py

Lines changed: 6 additions & 2 deletions

@@ -75,7 +75,10 @@ async def run_agent(

     # Initialize browser
     browser = Browser(config=browser_config)
-
+    tool_calling_method = "auto"
+    if "deepseek-r1" in request.llm_config.model_name:
+        tool_calling_method = "json_mode"
+
     # Initialize and run agent
     agent = Agent(
         task=request.task,
@@ -85,7 +88,8 @@ async def run_agent(
         generate_gif=request.generate_gif,
         max_failures=request.max_failures,
         memory_interval=request.memory_interval,
-        planner_interval=request.planner_interval
+        planner_interval=request.planner_interval,
+        tool_calling_method=tool_calling_method
     )

     result = await agent.run(max_steps=request.max_steps)
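
End to end, a caller can now route a task to an Ollama-served model purely through the request body. A hypothetical payload: provider, model_name, and temperature are real ModelConfig fields from settings.py, while the task text and any field not shown in this diff are illustrative:

import json

payload = {
    "task": "open example.com and report the page title",  # example task
    "max_steps": 5,
    "llm_config": {
        "provider": "ollama",
        "model_name": "deepseek-r1:14b",  # triggers tool_calling_method="json_mode"
        "temperature": 0.5,
    },
}
print(json.dumps(payload, indent=2))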

settings.py

Lines changed: 78 additions & 6 deletions

@@ -6,9 +6,12 @@
 from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_ollama import ChatOllama
 from pydantic import SecretStr
 from fastapi import HTTPException
 from logging_config import setup_logging, log_info, log_error, log_debug, log_warning
+from langchain_core.messages import BaseMessage, AIMessage
+from langchain_core.runnables import RunnableConfig

 # Logging configuration
 logger = logging.getLogger('browser-use.settings')
@@ -26,24 +29,21 @@ class ModelConfig(BaseModel):
     api_key: Optional[str] = Field(None, description="API key for the provider (if needed)")
     azure_endpoint: Optional[str] = Field(None, description="Endpoint for Azure OpenAI (if provider=azure)")
     azure_api_version: Optional[str] = Field(None, description="Azure OpenAI API version (if provider=azure)")
-    temperature: float = Field(0.0, description="Generation temperature (0.0 to 1.0)")
+    temperature: float = Field(0.5, description="Generation temperature (0.0 to 1.0)")
+    base_url: Optional[str] = Field(None, description="API base URL")


-# Database settings
 SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./browser_use.db")
 engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

-# API settings
 API_HOST = os.getenv("API_HOST", "0.0.0.0")
-API_PORT = int(os.getenv("API_PORT", "8000"))
+API_PORT = int(os.getenv("API_PORT", "9000"))
 API_DEBUG = os.getenv("API_DEBUG", "False").lower() == "true"

-# OpenAI settings
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")

-# Browser settings
 BROWSER_HEADLESS = os.getenv("BROWSER_HEADLESS", "True").lower() == "true"
 BROWSER_TIMEOUT = int(os.getenv("BROWSER_TIMEOUT", "30000"))

@@ -95,6 +95,19 @@ def get_llm(model_config: ModelConfig):
             azure_endpoint=model_config.azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT", ""),
             api_version=model_config.azure_api_version or "2024-10-21"
         )
+        elif provider == "ollama":
+            if "deepseek-r1" in model_config.model_name:
+                log_info(logger, "initializing special provider for ollama deepseek-r1")
+                return DeepSeekR1ChatOllama(
+                    model=model_config.model_name,
+                    temperature=model_config.temperature,
+                    # num_ctx=32000,
+                    base_url=os.getenv("OLLAMA_HOST")
+                )
+            else:
+                return ChatOllama(
+                    model=model_config.model_name
+                )
         else:
             raise ValueError(f"Unsupported provider: {provider}")
     except Exception as e:
@@ -104,3 +117,62 @@ def get_llm(model_config: ModelConfig):
             "error": str(e)
         }, exc_info=True)
         raise HTTPException(status_code=500, detail=f"Error initializing LLM: {str(e)}")
+
+class DeepSeekR1ChatOllama(ChatOllama):
+    """Custom chat model for DeepSeek-R1."""
+
+    def invoke(
+        self,
+        input: List[BaseMessage],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        """Invoke the chat model with DeepSeek-R1 specific processing."""
+        org_ai_message = super().invoke(input, config, **kwargs)
+        org_content = org_ai_message.content
+
+        # Extract reasoning content and main content
+        org_content = str(org_ai_message.content)
+        if "</think>" in org_content:
+            parts = org_content.split("</think>")
+            reasoning_content = parts[0].replace("<think>", "").strip()
+            content = parts[1].strip()
+
+            # Remove JSON Response tag if present
+            if "**JSON Response:**" in content:
+                content = content.split("**JSON Response:**")[-1].strip()
+
+            # Create AIMessage with extra attributes
+            message = AIMessage(content=content)
+            setattr(message, "reasoning_content", reasoning_content)
+            return message
+
+        return AIMessage(content=org_ai_message.content)
+
+    async def ainvoke(
+        self,
+        input: List[BaseMessage],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        """Async invoke the chat model with DeepSeek-R1 specific processing."""
+        org_ai_message = await super().ainvoke(input, config, **kwargs)
+        org_content = org_ai_message.content

+        # Extract reasoning content and main content
+        org_content = str(org_ai_message.content)
+        if "</think>" in org_content:
+            parts = org_content.split("</think>")
+            reasoning_content = parts[0].replace("<think>", "").strip()
+            content = parts[1].strip()
+
+            # Remove JSON Response tag if present
+            if "**JSON Response:**" in content:
+                content = content.split("**JSON Response:**")[-1].strip()
+
+            # Create AIMessage with extra attributes
+            message = AIMessage(content=content)
+            setattr(message, "reasoning_content", reasoning_content)
+            return message
+
+        return AIMessage(content=org_ai_message.content)
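
The subclass changes nothing about generation; it only post-processes the returned text. Everything before </think> is split off as reasoning_content, and the remainder, minus an optional **JSON Response:** marker, becomes the message body. The same parsing in isolation, with a sample string invented for illustration:

from langchain_core.messages import AIMessage

raw = (
    "<think>The task asks for a click action, so I answer in JSON.</think>\n"
    "**JSON Response:**\n"
    '{"action": "click", "selector": "#submit"}'
)

# Same splitting logic as DeepSeekR1ChatOllama.invoke/ainvoke
parts = raw.split("</think>")
reasoning = parts[0].replace("<think>", "").strip()
content = parts[1].strip()
if "**JSON Response:**" in content:
    content = content.split("**JSON Response:**")[-1].strip()

message = AIMessage(content=content)
print(message.content)  # {"action": "click", "selector": "#submit"}
print(reasoning)        # The task asks for a click action, so I answer in JSON.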
