diff --git a/workflows/.env.example b/workflows/.env.example
index 3d6a3f43..9b929fa2 100644
--- a/workflows/.env.example
+++ b/workflows/.env.example
@@ -1,2 +1,42 @@
-# We support all langchain models, openai only for demo purposes
+# We support all LangChain chat models; pick a provider below (OpenAI is just the default)
-OPENAI_API_KEY=
\ No newline at end of file
+LLM_PROVIDER=openai
+MODEL_NAME=gpt-4o
+
+OPENAI_ENDPOINT=https://api.openai.com/v1
+OPENAI_API_KEY=
+
+ANTHROPIC_API_KEY=
+ANTHROPIC_ENDPOINT=https://api.anthropic.com
+
+GOOGLE_API_KEY=
+
+AZURE_OPENAI_ENDPOINT=
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_API_VERSION=2025-01-01-preview
+
+DEEPSEEK_ENDPOINT=https://api.deepseek.com
+DEEPSEEK_API_KEY=
+
+MISTRAL_API_KEY=
+MISTRAL_ENDPOINT=https://api.mistral.ai/v1
+
+OLLAMA_ENDPOINT=http://localhost:11434
+
+ALIBABA_ENDPOINT=https://dashscope.aliyuncs.com/compatible-mode/v1
+ALIBABA_API_KEY=
+
+MOONSHOT_ENDPOINT=https://api.moonshot.cn/v1
+MOONSHOT_API_KEY=
+
+UNBOUND_ENDPOINT=https://api.getunbound.ai
+UNBOUND_API_KEY=
+
+SILICONFLOW_ENDPOINT=https://api.siliconflow.cn/v1/
+SILICONFLOW_API_KEY=
+
+IBM_ENDPOINT=https://us-south.ml.cloud.ibm.com
+IBM_API_KEY=
+IBM_PROJECT_ID=
+
+GROK_ENDPOINT=https://api.x.ai/v1
+GROK_API_KEY=
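+
+# Read by the "modelscope" provider in llm_provider.py (no default endpoint is assumed)
+MODELSCOPE_ENDPOINT=
+MODELSCOPE_API_KEY=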
\ No newline at end of file
diff --git a/workflows/cli.py b/workflows/cli.py
index 6d7193a6..6a149e28 100644
--- a/workflows/cli.py
+++ b/workflows/cli.py
@@ -1,3 +1,5 @@
+from dotenv import load_dotenv
+
+# Load .env before importing modules that may read configuration at import time
+load_dotenv()
import asyncio
import json
import os
@@ -5,18 +7,18 @@
import tempfile # For temporary file handling
import webbrowser
from pathlib import Path
import typer
from browser_use.browser.browser import Browser
-# Assuming OPENAI_API_KEY is set in the environment
-from langchain_openai import ChatOpenAI
-
from workflow_use.builder.service import BuilderService
from workflow_use.controller.service import WorkflowController
from workflow_use.mcp.service import get_mcp_server
from workflow_use.recorder.service import RecordingService # Added import
from workflow_use.workflow.service import Workflow
+from workflow_use.llm.llm_provider import get_llm_model
+from workflow_use.llm.config import model_names
# Placeholder for recorder functionality
# from src.recorder.service import RecorderService
@@ -31,15 +33,27 @@
# Default LLM instance to None
llm_instance = None
try:
- llm_instance = ChatOpenAI(model='gpt-4o')
- page_extraction_llm = ChatOpenAI(model='gpt-4o-mini')
+ # Get provider and model name from environment or default to openai
+ provider = os.getenv("LLM_PROVIDER", "openai").lower()
+ model_name = os.getenv("MODEL_NAME", "")
+
+    # Fail fast on providers we have no model defaults for
+    if provider not in model_names:
+        raise ValueError(f"Unsupported LLM_PROVIDER: {provider}. Options: {', '.join(model_names)}")
+
+    # If no model name is specified, list the provider's known models and prompt
+    if not model_name:
+        typer.echo(f"Available models for {provider}:")
+        typer.echo(", ".join(model_names[provider]))
+
+        model_name = typer.prompt(f"Model name for {provider}", default=model_names[provider][0])
+        os.environ["MODEL_NAME"] = model_name
+
+ # Initialize LLM with selected provider and model
+ llm_instance = get_llm_model(provider, model_name=model_name)
+ page_extraction_llm = get_llm_model(provider, model_name=model_name)
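+    # Note: the builder LLM and page_extraction_llm share the selected model here;
+    # the previous default used a cheaper model (gpt-4o-mini) for extraction.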
+
except Exception as e:
- typer.secho(f'Error initializing LLM: {e}. Would you like to set your OPENAI_API_KEY?', fg=typer.colors.RED)
- set_openai_api_key = input('Set OPENAI_API_KEY? (y/n): ')
- if set_openai_api_key.lower() == 'y':
- os.environ['OPENAI_API_KEY'] = input('Enter your OPENAI_API_KEY: ')
- llm_instance = ChatOpenAI(model='gpt-4o')
- page_extraction_llm = ChatOpenAI(model='gpt-4o-mini')
+    typer.secho(
+        f"Error initializing LLM: {e}. Set LLM_PROVIDER, MODEL_NAME, and the matching API key in your .env (see .env.example).",
+        fg=typer.colors.RED,
+    )
+    page_extraction_llm = None
builder_service = BuilderService(llm=llm_instance) if llm_instance else None
# recorder_service = RecorderService() # Placeholder
diff --git a/workflows/requirements.txt b/workflows/requirements.txt
new file mode 100644
index 00000000..47c5f0cd
Binary files /dev/null and b/workflows/requirements.txt differ
diff --git a/workflows/workflow_use/llm/config.py b/workflows/workflow_use/llm/config.py
new file mode 100644
index 00000000..4c92d427
--- /dev/null
+++ b/workflows/workflow_use/llm/config.py
@@ -0,0 +1,87 @@
+# Predefined model names for common providers
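+# Keys must match the provider names handled by llm_provider.get_llm_model;
+# the CLI uses the first entry of each list as that provider's default model.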
+model_names = {
+ "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
+ "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
+ "deepseek": ["deepseek-chat", "deepseek-reasoner"],
+ "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
+ "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05",
+ "gemini-2.5-pro-preview-03-25", "gemini-2.5-flash-preview-04-17"],
+ "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
+ "deepseek-r1:14b", "deepseek-r1:32b"],
+ "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
+ "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
+ "alibaba": ["qwen-plus", "qwen-max", "qwen-vl-max", "qwen-vl-plus", "qwen-turbo", "qwen-long"],
+ "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
+ "unbound": ["gemini-2.0-flash", "gpt-4o-mini", "gpt-4o", "gpt-4.5-preview"],
+ "grok": [
+ "grok-3",
+ "grok-3-fast",
+ "grok-3-mini",
+ "grok-3-mini-fast",
+ "grok-2-vision",
+ "grok-2-image",
+ "grok-2",
+ ],
+ "siliconflow": [
+ "deepseek-ai/DeepSeek-R1",
+ "deepseek-ai/DeepSeek-V3",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+ "deepseek-ai/DeepSeek-V2.5",
+ "deepseek-ai/deepseek-vl2",
+ "Qwen/Qwen2.5-72B-Instruct-128K",
+ "Qwen/Qwen2.5-72B-Instruct",
+ "Qwen/Qwen2.5-32B-Instruct",
+ "Qwen/Qwen2.5-14B-Instruct",
+ "Qwen/Qwen2.5-7B-Instruct",
+ "Qwen/Qwen2.5-Coder-32B-Instruct",
+ "Qwen/Qwen2.5-Coder-7B-Instruct",
+ "Qwen/Qwen2-7B-Instruct",
+ "Qwen/Qwen2-1.5B-Instruct",
+ "Qwen/QwQ-32B-Preview",
+ "Qwen/Qwen2-VL-72B-Instruct",
+ "Qwen/Qwen2.5-VL-32B-Instruct",
+ "Qwen/Qwen2.5-VL-72B-Instruct",
+ "TeleAI/TeleChat2",
+ "THUDM/glm-4-9b-chat",
+ "Vendor-A/Qwen/Qwen2.5-72B-Instruct",
+ "internlm/internlm2_5-7b-chat",
+ "internlm/internlm2_5-20b-chat",
+ "Pro/Qwen/Qwen2.5-7B-Instruct",
+ "Pro/Qwen/Qwen2-7B-Instruct",
+ "Pro/Qwen/Qwen2-1.5B-Instruct",
+ "Pro/THUDM/chatglm3-6b",
+ "Pro/THUDM/glm-4-9b-chat",
+ ],
+ "ibm": ["ibm/granite-vision-3.1-2b-preview", "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
+ "meta-llama/llama-3-2-90b-vision-instruct"],
+ "modelscope":[
+ "Qwen/Qwen2.5-Coder-32B-Instruct",
+ "Qwen/Qwen2.5-Coder-14B-Instruct",
+ "Qwen/Qwen2.5-Coder-7B-Instruct",
+ "Qwen/Qwen2.5-72B-Instruct",
+ "Qwen/Qwen2.5-32B-Instruct",
+ "Qwen/Qwen2.5-14B-Instruct",
+ "Qwen/Qwen2.5-7B-Instruct",
+ "Qwen/QwQ-32B-Preview",
+ "Qwen/Qwen2.5-VL-3B-Instruct",
+ "Qwen/Qwen2.5-VL-7B-Instruct",
+ "Qwen/Qwen2.5-VL-32B-Instruct",
+ "Qwen/Qwen2.5-VL-72B-Instruct",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+ "deepseek-ai/DeepSeek-R1",
+ "deepseek-ai/DeepSeek-V3",
+ "Qwen/Qwen3-1.7B",
+ "Qwen/Qwen3-4B",
+ "Qwen/Qwen3-8B",
+ "Qwen/Qwen3-14B",
+ "Qwen/Qwen3-30B-A3B",
+ "Qwen/Qwen3-32B",
+ "Qwen/Qwen3-235B-A22B",
+ ],
+}
diff --git a/workflows/workflow_use/llm/llm_provider.py b/workflows/workflow_use/llm/llm_provider.py
new file mode 100644
index 00000000..d878bab6
--- /dev/null
+++ b/workflows/workflow_use/llm/llm_provider.py
@@ -0,0 +1,351 @@
+import os
+from typing import Any, Optional
+
+from openai import OpenAI
+
+from langchain_anthropic import ChatAnthropic
+from langchain_core.language_models.base import LanguageModelInput
+from langchain_core.messages import AIMessage, SystemMessage
+from langchain_core.runnables import RunnableConfig
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_ibm import ChatWatsonx
+from langchain_mistralai import ChatMistralAI
+from langchain_ollama import ChatOllama
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
+
+
+class DeepSeekR1ChatOpenAI(ChatOpenAI):
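+    """ChatOpenAI variant for DeepSeek-R1 that calls the OpenAI-compatible API
+    directly so the `reasoning_content` field is kept on the returned message."""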
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+ self.client = OpenAI(
+ base_url=kwargs.get("base_url"),
+ api_key=kwargs.get("api_key")
+ )
+
+ async def ainvoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> AIMessage:
+ message_history = []
+ for input_ in input:
+ if isinstance(input_, SystemMessage):
+ message_history.append({"role": "system", "content": input_.content})
+ elif isinstance(input_, AIMessage):
+ message_history.append({"role": "assistant", "content": input_.content})
+ else:
+ message_history.append({"role": "user", "content": input_.content})
+
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=message_history
+ )
+
+ reasoning_content = response.choices[0].message.reasoning_content
+ content = response.choices[0].message.content
+ return AIMessage(content=content, reasoning_content=reasoning_content)
+
+ def invoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> AIMessage:
+ message_history = []
+ for input_ in input:
+ if isinstance(input_, SystemMessage):
+ message_history.append({"role": "system", "content": input_.content})
+ elif isinstance(input_, AIMessage):
+ message_history.append({"role": "assistant", "content": input_.content})
+ else:
+ message_history.append({"role": "user", "content": input_.content})
+
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=message_history
+ )
+
+ reasoning_content = response.choices[0].message.reasoning_content
+ content = response.choices[0].message.content
+ return AIMessage(content=content, reasoning_content=reasoning_content)
+
+
+class DeepSeekR1ChatOllama(ChatOllama):
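+    """ChatOllama variant for DeepSeek-R1 that splits the <think>...</think>
+    block out of the response text into a separate `reasoning_content` field."""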
+
+ async def ainvoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> AIMessage:
+ org_ai_message = await super().ainvoke(input=input)
+ org_content = org_ai_message.content
+        reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
+        content = org_content.split("</think>")[1]
+ if "**JSON Response:**" in content:
+ content = content.split("**JSON Response:**")[-1]
+ return AIMessage(content=content, reasoning_content=reasoning_content)
+
+ def invoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> AIMessage:
+ org_ai_message = super().invoke(input=input)
+ org_content = org_ai_message.content
+        reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
+        content = org_content.split("</think>")[1]
+ if "**JSON Response:**" in content:
+ content = content.split("**JSON Response:**")[-1]
+ return AIMessage(content=content, reasoning_content=reasoning_content)
+
+
+def get_llm_model(provider: str, **kwargs):
+ """
+ Get LLM model
+ :param provider: LLM provider
+ :param kwargs:
+ :return:
+ """
+ if provider not in ["ollama", "bedrock"]:
+ env_var = f"{provider.upper()}_API_KEY"
+ api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
+ if not api_key:
+ error_msg = f"🔑 Please set the `{env_var}` environment variable."
+ raise ValueError(error_msg)
+ kwargs["api_key"] = api_key
+
+ if provider == "anthropic":
+ if not kwargs.get("base_url", ""):
+ base_url = "https://api.anthropic.com"
+ else:
+ base_url = kwargs.get("base_url")
+
+ return ChatAnthropic(
+ model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+    elif provider == "mistral":
+        if not kwargs.get("base_url", ""):
+            base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
+        else:
+            base_url = kwargs.get("base_url")
+
+ return ChatMistralAI(
+ model=kwargs.get("model_name", "mistral-large-latest"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+ elif provider == "openai":
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
+ else:
+ base_url = kwargs.get("base_url")
+
+ return ChatOpenAI(
+ model=kwargs.get("model_name", "gpt-4o"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+ elif provider == "grok":
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1")
+ else:
+ base_url = kwargs.get("base_url")
+
+ return ChatOpenAI(
+ model=kwargs.get("model_name", "grok-3"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+ elif provider == "deepseek":
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
+ else:
+ base_url = kwargs.get("base_url")
+
+ if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
+ return DeepSeekR1ChatOpenAI(
+ model=kwargs.get("model_name", "deepseek-reasoner"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+ else:
+ return ChatOpenAI(
+ model=kwargs.get("model_name", "deepseek-chat"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+ elif provider == "google":
+ return ChatGoogleGenerativeAI(
+ model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
+ temperature=kwargs.get("temperature", 0.0),
+ api_key=api_key,
+ )
+ elif provider == "ollama":
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
+ else:
+ base_url = kwargs.get("base_url")
+
+ if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
+ return DeepSeekR1ChatOllama(
+ model=kwargs.get("model_name", "deepseek-r1:14b"),
+ temperature=kwargs.get("temperature", 0.0),
+ num_ctx=kwargs.get("num_ctx", 32000),
+ base_url=base_url,
+ )
+ else:
+ return ChatOllama(
+ model=kwargs.get("model_name", "qwen2.5:7b"),
+ temperature=kwargs.get("temperature", 0.0),
+ num_ctx=kwargs.get("num_ctx", 32000),
+ num_predict=kwargs.get("num_predict", 1024),
+ base_url=base_url,
+ )
+ elif provider == "azure_openai":
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
+ else:
+ base_url = kwargs.get("base_url")
+ api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
+ return AzureChatOpenAI(
+ model=kwargs.get("model_name", "gpt-4o"),
+ temperature=kwargs.get("temperature", 0.0),
+ api_version=api_version,
+ azure_endpoint=base_url,
+ api_key=api_key,
+ )
+ elif provider == "alibaba":
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+ else:
+ base_url = kwargs.get("base_url")
+
+ return ChatOpenAI(
+ model=kwargs.get("model_name", "qwen-plus"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=base_url,
+ api_key=api_key,
+ )
+ elif provider == "ibm":
+ parameters = {
+ "temperature": kwargs.get("temperature", 0.0),
+ "max_tokens": kwargs.get("num_ctx", 32000)
+ }
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
+ else:
+ base_url = kwargs.get("base_url")
+
+ return ChatWatsonx(
+ model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
+ url=base_url,
+ project_id=os.getenv("IBM_PROJECT_ID"),
+ apikey=os.getenv("IBM_API_KEY"),
+ params=parameters
+ )
+ elif provider == "moonshot":
+ return ChatOpenAI(
+ model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=os.getenv("MOONSHOT_ENDPOINT"),
+ api_key=os.getenv("MOONSHOT_API_KEY"),
+ )
+ elif provider == "unbound":
+ return ChatOpenAI(
+ model=kwargs.get("model_name", "gpt-4o-mini"),
+ temperature=kwargs.get("temperature", 0.0),
+ base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
+ api_key=api_key,
+ )
+ elif provider == "siliconflow":
+ if not kwargs.get("api_key", ""):
+ api_key = os.getenv("SiliconFLOW_API_KEY", "")
+ else:
+ api_key = kwargs.get("api_key")
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("SiliconFLOW_ENDPOINT", "")
+ else:
+ base_url = kwargs.get("base_url")
+ return ChatOpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
+ temperature=kwargs.get("temperature", 0.0),
+ )
+ elif provider == "modelscope":
+ if not kwargs.get("api_key", ""):
+ api_key = os.getenv("MODELSCOPE_API_KEY", "")
+ else:
+ api_key = kwargs.get("api_key")
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("MODELSCOPE_ENDPOINT", "")
+ else:
+ base_url = kwargs.get("base_url")
+ return ChatOpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
+ temperature=kwargs.get("temperature", 0.0),
+ )
+ else:
+ raise ValueError(f"Unsupported provider: {provider}")