66from dotenv import load_dotenv
77from pydantic import BaseModel , Field
88from langchain_openai import ChatOpenAI , AzureChatOpenAI
9+ from langchain_ollama import ChatOllama
910from pydantic import SecretStr
1011from fastapi import HTTPException
1112from logging_config import setup_logging , log_info , log_error , log_debug , log_warning
13+ from langchain_core .messages import BaseMessage , AIMessage
14+ from langchain_core .runnables import RunnableConfig
1215
1316# Logging configuration
1417logger = logging .getLogger ('browser-use.settings' )
@@ -26,24 +29,21 @@ class ModelConfig(BaseModel):
2629 api_key : Optional [str ] = Field (None , description = "API key for the provider (if needed)" )
2730 azure_endpoint : Optional [str ] = Field (None , description = "Endpoint for Azure OpenAI (if provider=azure)" )
2831 azure_api_version : Optional [str ] = Field (None , description = "Azure OpenAI API version (if provider=azure)" )
29- temperature : float = Field (0.0 , description = "Generation temperature (0.0 to 1.0)" )
32+ temperature : float = Field (0.5 , description = "Generation temperature (0.0 to 1.0)" )
33+ base_url : Optional [str ] = Field (None , description = "api base url" )
3034
3135
# Database configuration: default to a local SQLite file when
# DATABASE_URL is not provided by the environment.
SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./browser_use.db")
# check_same_thread=False lets the SQLite connection be shared across
# threads (FastAPI serves requests from a thread pool).
# NOTE(review): connect_args is SQLite-specific; create_engine will reject it
# if DATABASE_URL points at another backend — confirm intended backends.
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# API server configuration (host/port the FastAPI app binds to).
API_HOST = os.getenv("API_HOST", "0.0.0.0")
API_PORT = int(os.getenv("API_PORT", "9000"))
# Debug mode is enabled only by the literal (case-insensitive) string "true".
API_DEBUG = os.getenv("API_DEBUG", "False").lower() == "true"

# OpenAI defaults used when a request does not specify its own model config.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")

# Browser automation settings (headless flag and timeout in milliseconds —
# presumably ms given the 30000 default; TODO confirm against browser client).
BROWSER_HEADLESS = os.getenv("BROWSER_HEADLESS", "True").lower() == "true"
BROWSER_TIMEOUT = int(os.getenv("BROWSER_TIMEOUT", "30000"))
@@ -95,6 +95,19 @@ def get_llm(model_config: ModelConfig):
9595 azure_endpoint = model_config .azure_endpoint or os .getenv ("AZURE_OPENAI_ENDPOINT" , "" ),
9696 api_version = model_config .azure_api_version or "2024-10-21"
9797 )
98+ elif provider == "ollama" :
99+ if "deepseek-r1" in model_config .model_name :
100+ log_info (logger , "initializing special provider for ollama deepseek-r1" )
101+ return DeepSeekR1ChatOllama (
102+ model = model_config .model_name ,
103+ temperature = model_config .temperature ,
104+ # num_ctx=32000,
105+ base_url = os .getenv ("OLLAMA_HOST" )
106+ )
107+ else :
108+ return ChatOllama (
109+ model = model_config .model_name
110+ )
98111 else :
99112 raise ValueError (f"Unsupported provider: { provider } " )
100113 except Exception as e :
@@ -104,3 +117,62 @@ def get_llm(model_config: ModelConfig):
104117 "error" : str (e )
105118 }, exc_info = True )
106119 raise HTTPException (status_code = 500 , detail = f"Error initializing LLM: { str (e )} " )
120+
class DeepSeekR1ChatOllama(ChatOllama):
    """Chat model wrapper for DeepSeek-R1 served through Ollama.

    DeepSeek-R1 emits its chain-of-thought wrapped in ``<think>...</think>``
    tags ahead of the final answer. This subclass strips that reasoning
    block out of the returned content and attaches it to the message as a
    ``reasoning_content`` attribute instead, so downstream consumers see
    only the final answer in ``content``.
    """

    @staticmethod
    def _split_reasoning(content) -> AIMessage:
        """Split a raw DeepSeek-R1 completion into answer + reasoning.

        Args:
            content: Raw ``content`` from the model's message (usually str).

        Returns:
            An ``AIMessage``. When a ``</think>`` marker is present, its
            ``content`` is the text after the marker (with any leading
            ``**JSON Response:**`` header removed) and the reasoning text is
            attached as a ``reasoning_content`` attribute. Otherwise the
            original content is passed through unchanged.
        """
        text = str(content)
        if "</think>" not in text:
            # No reasoning block: return the content as-is (not stringified),
            # matching the behavior expected for plain completions.
            return AIMessage(content=content)

        reasoning_part, _, answer_part = text.partition("</think>")
        reasoning_content = reasoning_part.replace("<think>", "").strip()
        answer = answer_part.strip()

        # Some prompts make the model prefix the answer with a
        # "**JSON Response:**" header; keep only what follows it.
        if "**JSON Response:**" in answer:
            answer = answer.split("**JSON Response:**")[-1].strip()

        message = AIMessage(content=answer)
        # Extra attribute (not part of the AIMessage schema) so callers can
        # still inspect the model's chain of thought when they need it.
        setattr(message, "reasoning_content", reasoning_content)
        return message

    def invoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Invoke the underlying model and post-process the R1 output."""
        raw_message = super().invoke(input, config, **kwargs)
        return self._split_reasoning(raw_message.content)

    async def ainvoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async variant of :meth:`invoke` with the same post-processing."""
        raw_message = await super().ainvoke(input, config, **kwargs)
        return self._split_reasoning(raw_message.content)
0 commit comments