diff --git a/workflows/.env.example b/workflows/.env.example
index 3d6a3f43..9b929fa2 100644
--- a/workflows/.env.example
+++ b/workflows/.env.example
@@ -1,2 +1,45 @@
 # We support all langchain models, openai only for demo purposes
-OPENAI_API_KEY=
\ No newline at end of file
+LLM_PROVIDER="openai"
+MODEL_NAME="gpt-4o"
+
+OPENAI_ENDPOINT=https://api.openai.com/v1
+OPENAI_API_KEY=
+
+ANTHROPIC_API_KEY=
+ANTHROPIC_ENDPOINT=https://api.anthropic.com
+
+GOOGLE_API_KEY=
+
+AZURE_OPENAI_ENDPOINT=
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_API_VERSION=2025-01-01-preview
+
+DEEPSEEK_ENDPOINT=https://api.deepseek.com
+DEEPSEEK_API_KEY=
+
+MISTRAL_API_KEY=
+MISTRAL_ENDPOINT=https://api.mistral.ai/v1
+
+OLLAMA_ENDPOINT=http://localhost:11434
+
+ALIBABA_ENDPOINT=https://dashscope.aliyuncs.com/compatible-mode/v1
+ALIBABA_API_KEY=
+
+MOONSHOT_ENDPOINT=https://api.moonshot.cn/v1
+MOONSHOT_API_KEY=
+
+UNBOUND_ENDPOINT=https://api.getunbound.ai
+UNBOUND_API_KEY=
+
+SILICONFLOW_ENDPOINT=https://api.siliconflow.cn/v1/
+SILICONFLOW_API_KEY=
+
+MODELSCOPE_ENDPOINT=
+MODELSCOPE_API_KEY=
+
+IBM_ENDPOINT=https://us-south.ml.cloud.ibm.com
+IBM_API_KEY=
+IBM_PROJECT_ID=
+
+GROK_ENDPOINT=https://api.x.ai/v1
+GROK_API_KEY=
\ No newline at end of file
diff --git a/workflows/cli.py b/workflows/cli.py
index 6d7193a6..6a149e28 100644
--- a/workflows/cli.py
+++ b/workflows/cli.py
@@ -1,3 +1,5 @@
+from dotenv import load_dotenv
+load_dotenv()
 import asyncio
 import json
 import os
@@ -5,18 +7,18 @@
 import tempfile  # For temporary file handling
 import webbrowser
 from pathlib import Path
 
 import typer
 from browser_use.browser.browser import Browser
 
-# Assuming OPENAI_API_KEY is set in the environment
-from langchain_openai import ChatOpenAI
-
 from workflow_use.builder.service import BuilderService
 from workflow_use.controller.service import WorkflowController
 from workflow_use.mcp.service import get_mcp_server
 from workflow_use.recorder.service import RecordingService  # Added import
 from workflow_use.workflow.service import Workflow
+from workflow_use.llm.llm_provider import get_llm_model
+from workflow_use.llm.config import model_names
 
 # Placeholder for recorder functionality
 # from src.recorder.service import RecorderService
@@ -31,15 +33,27 @@
 # Default LLM instance to None
 llm_instance = None
 try:
-    llm_instance = ChatOpenAI(model='gpt-4o')
-    page_extraction_llm = ChatOpenAI(model='gpt-4o-mini')
+    # Get the provider and model name from the environment, defaulting to openai
+    provider = os.getenv("LLM_PROVIDER", "openai").lower()
+    model_name = os.getenv("MODEL_NAME", "")
+
+    # If no model name is specified, list the known models and prompt the user
+    if not model_name:
+        typer.echo(f"Available models for {provider}:")
+        typer.echo(f"{model_names[provider]}")
+
+        model_name = typer.prompt(f"Model name for {provider}", default=model_names[provider][0])
+        os.environ["MODEL_NAME"] = model_name
+
+    # Initialize the LLMs with the selected provider and model
+    llm_instance = get_llm_model(provider, model_name=model_name)
+    page_extraction_llm = get_llm_model(provider, model_name=model_name)
 except Exception as e:
-    typer.secho(f'Error initializing LLM: {e}. Would you like to set your OPENAI_API_KEY?', fg=typer.colors.RED)
-    set_openai_api_key = input('Set OPENAI_API_KEY? (y/n): ')
-    if set_openai_api_key.lower() == 'y':
-        os.environ['OPENAI_API_KEY'] = input('Enter your OPENAI_API_KEY: ')
-        llm_instance = ChatOpenAI(model='gpt-4o')
-        page_extraction_llm = ChatOpenAI(model='gpt-4o-mini')
+    typer.secho(
+        f"Error initializing LLM: {e}. Please set the API key for your provider in .env (see .env.example).",
+        fg=typer.colors.RED,
+    )
 
 builder_service = BuilderService(llm=llm_instance) if llm_instance else None
 # recorder_service = RecorderService()  # Placeholder
diff --git a/workflows/requirements.txt b/workflows/requirements.txt
new file mode 100644
index 00000000..47c5f0cd
Binary files /dev/null and b/workflows/requirements.txt differ
diff --git a/workflows/workflow_use/llm/config.py b/workflows/workflow_use/llm/config.py
new file mode 100644
index 00000000..4c92d427
--- /dev/null
+++ b/workflows/workflow_use/llm/config.py
@@ -0,0 +1,87 @@
+# Predefined model names for common providers
+model_names = {
+    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
+    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
+    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
+    "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
+               "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05",
+               "gemini-2.5-pro-preview-03-25", "gemini-2.5-flash-preview-04-17"],
+    "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
+               "deepseek-r1:14b", "deepseek-r1:32b"],
+    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
+    "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
+    "alibaba": ["qwen-plus", "qwen-max", "qwen-vl-max", "qwen-vl-plus", "qwen-turbo", "qwen-long"],
+    "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
+    "unbound": ["gemini-2.0-flash", "gpt-4o-mini", "gpt-4o", "gpt-4.5-preview"],
+    "grok": [
+        "grok-3",
+        "grok-3-fast",
+        "grok-3-mini",
+        "grok-3-mini-fast",
+        "grok-2-vision",
+        "grok-2-image",
+        "grok-2",
+    ],
+    "siliconflow": [
+        "deepseek-ai/DeepSeek-R1",
+        "deepseek-ai/DeepSeek-V3",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "deepseek-ai/DeepSeek-V2.5",
+        "deepseek-ai/deepseek-vl2",
+        "Qwen/Qwen2.5-72B-Instruct-128K",
+        "Qwen/Qwen2.5-72B-Instruct",
+        "Qwen/Qwen2.5-32B-Instruct",
+        "Qwen/Qwen2.5-14B-Instruct",
+        "Qwen/Qwen2.5-7B-Instruct",
+        "Qwen/Qwen2.5-Coder-32B-Instruct",
+        "Qwen/Qwen2.5-Coder-7B-Instruct",
+        "Qwen/Qwen2-7B-Instruct",
+        "Qwen/Qwen2-1.5B-Instruct",
+        "Qwen/QwQ-32B-Preview",
+        "Qwen/Qwen2-VL-72B-Instruct",
+        "Qwen/Qwen2.5-VL-32B-Instruct",
+        "Qwen/Qwen2.5-VL-72B-Instruct",
+        "TeleAI/TeleChat2",
+        "THUDM/glm-4-9b-chat",
+        "Vendor-A/Qwen/Qwen2.5-72B-Instruct",
+        "internlm/internlm2_5-7b-chat",
+        "internlm/internlm2_5-20b-chat",
+        "Pro/Qwen/Qwen2.5-7B-Instruct",
+        "Pro/Qwen/Qwen2-7B-Instruct",
+        "Pro/Qwen/Qwen2-1.5B-Instruct",
+        "Pro/THUDM/chatglm3-6b",
+        "Pro/THUDM/glm-4-9b-chat",
+    ],
+    "ibm": ["ibm/granite-vision-3.1-2b-preview", "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
+            "meta-llama/llama-3-2-90b-vision-instruct"],
+    "modelscope": [
+        "Qwen/Qwen2.5-Coder-32B-Instruct",
+        "Qwen/Qwen2.5-Coder-14B-Instruct",
+        "Qwen/Qwen2.5-Coder-7B-Instruct",
+        "Qwen/Qwen2.5-72B-Instruct",
+        "Qwen/Qwen2.5-32B-Instruct",
+        "Qwen/Qwen2.5-14B-Instruct",
+        "Qwen/Qwen2.5-7B-Instruct",
+        "Qwen/QwQ-32B-Preview",
+        "Qwen/Qwen2.5-VL-3B-Instruct",
+        "Qwen/Qwen2.5-VL-7B-Instruct",
+        "Qwen/Qwen2.5-VL-32B-Instruct",
+        "Qwen/Qwen2.5-VL-72B-Instruct",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "deepseek-ai/DeepSeek-R1",
+        "deepseek-ai/DeepSeek-V3",
+        "Qwen/Qwen3-1.7B",
+        "Qwen/Qwen3-4B",
+        "Qwen/Qwen3-8B",
+        "Qwen/Qwen3-14B",
+        "Qwen/Qwen3-30B-A3B",
+        "Qwen/Qwen3-32B",
+        "Qwen/Qwen3-235B-A22B",
+    ],
+}
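Note: the provider/model selection wired into cli.py above can be exercised standalone. The following snippet is illustrative only (not part of the patch) and assumes nothing beyond the model_names dict defined in workflow_use/llm/config.py:

    import os

    from workflow_use.llm.config import model_names

    # Resolve provider and model the same way cli.py does: environment
    # variables first, then the first known model for that provider.
    provider = os.getenv("LLM_PROVIDER", "openai").lower()
    model = os.getenv("MODEL_NAME") or model_names.get(provider, ["gpt-4o"])[0]
    print(f"Selected {provider} / {model}")
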
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "deepseek-ai/DeepSeek-R1", + "deepseek-ai/DeepSeek-V3", + "Qwen/Qwen3-1.7B", + "Qwen/Qwen3-4B", + "Qwen/Qwen3-8B", + "Qwen/Qwen3-14B", + "Qwen/Qwen3-30B-A3B", + "Qwen/Qwen3-32B", + "Qwen/Qwen3-235B-A22B", + ], +} diff --git a/workflows/workflow_use/llm/llm_provider.py b/workflows/workflow_use/llm/llm_provider.py new file mode 100644 index 00000000..d878bab6 --- /dev/null +++ b/workflows/workflow_use/llm/llm_provider.py @@ -0,0 +1,351 @@ +from openai import OpenAI +import pdb +from langchain_openai import ChatOpenAI +from langchain_core.globals import get_llm_cache +from langchain_core.language_models.base import ( + BaseLanguageModel, + LangSmithParams, + LanguageModelInput, +) +import os +from langchain_core.load import dumpd, dumps +from langchain_core.messages import ( + AIMessage, + SystemMessage, + AnyMessage, + BaseMessage, + BaseMessageChunk, + HumanMessage, + convert_to_messages, + message_chunk_to_message, +) +from langchain_core.outputs import ( + ChatGeneration, + ChatGenerationChunk, + ChatResult, + LLMResult, + RunInfo, +) +from langchain_ollama import ChatOllama +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.tools import BaseTool + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Union, + cast, List, +) +from langchain_anthropic import ChatAnthropic +from langchain_mistralai import ChatMistralAI +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_ollama import ChatOllama +from langchain_openai import AzureChatOpenAI, ChatOpenAI +from langchain_ibm import ChatWatsonx +from langchain_aws import ChatBedrock +from pydantic import SecretStr + + +class DeepSeekR1ChatOpenAI(ChatOpenAI): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.client = OpenAI( + base_url=kwargs.get("base_url"), + api_key=kwargs.get("api_key") + ) + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + message_history = [] + for input_ in input: + if isinstance(input_, SystemMessage): + message_history.append({"role": "system", "content": input_.content}) + elif isinstance(input_, AIMessage): + message_history.append({"role": "assistant", "content": input_.content}) + else: + message_history.append({"role": "user", "content": input_.content}) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=message_history + ) + + reasoning_content = response.choices[0].message.reasoning_content + content = response.choices[0].message.content + return AIMessage(content=content, reasoning_content=reasoning_content) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + message_history = [] + for input_ in input: + if isinstance(input_, SystemMessage): + message_history.append({"role": "system", "content": input_.content}) + elif isinstance(input_, AIMessage): + message_history.append({"role": "assistant", "content": input_.content}) + else: + message_history.append({"role": "user", "content": input_.content}) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=message_history + ) + + reasoning_content = 
+
+
+def get_llm_model(provider: str, **kwargs):
+    """Instantiate a langchain chat model for the given provider.
+
+    :param provider: provider key, e.g. "openai", "anthropic", "ollama".
+    :param kwargs: optional overrides (model_name, temperature, base_url,
+        api_key, ...); anything unset falls back to environment variables.
+    :return: a configured langchain chat model.
+    """
+    # Every provider except credential-less local ones requires an API key,
+    # taken from kwargs or from the <PROVIDER>_API_KEY environment variable.
+    if provider not in ["ollama", "bedrock"]:
+        env_var = f"{provider.upper()}_API_KEY"
+        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
+        if not api_key:
+            error_msg = f"🔑 Please set the `{env_var}` environment variable."
+            raise ValueError(error_msg)
+        kwargs["api_key"] = api_key
+
+    if provider == "anthropic":
+        base_url = kwargs.get("base_url") or os.getenv("ANTHROPIC_ENDPOINT", "https://api.anthropic.com")
+        return ChatAnthropic(
+            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "mistral":
+        base_url = kwargs.get("base_url") or os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
+        return ChatMistralAI(
+            model=kwargs.get("model_name", "mistral-large-latest"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "openai":
+        base_url = kwargs.get("base_url") or os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "gpt-4o"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "grok":
+        # Grok exposes an OpenAI-compatible API.
+        base_url = kwargs.get("base_url") or os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "grok-3"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "deepseek":
+        base_url = kwargs.get("base_url") or os.getenv("DEEPSEEK_ENDPOINT", "")
+        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
+            # deepseek-reasoner returns a separate reasoning_content field,
+            # so it needs the custom wrapper defined above.
+            return DeepSeekR1ChatOpenAI(
+                model=kwargs.get("model_name", "deepseek-reasoner"),
+                temperature=kwargs.get("temperature", 0.0),
+                base_url=base_url,
+                api_key=api_key,
+            )
+        else:
+            return ChatOpenAI(
+                model=kwargs.get("model_name", "deepseek-chat"),
+                temperature=kwargs.get("temperature", 0.0),
+                base_url=base_url,
+                api_key=api_key,
+            )
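Note: a usage sketch for the routing above (illustrative, not part of the patch; assumes DEEPSEEK_API_KEY is set and the prompt is a placeholder). "deepseek-reasoner" selects the DeepSeekR1ChatOpenAI wrapper; any other model name falls through to plain ChatOpenAI against the same endpoint:

    from workflow_use.llm.llm_provider import get_llm_model

    llm = get_llm_model("deepseek", model_name="deepseek-reasoner")
    msg = llm.invoke("Why is the sky blue? One sentence.")
    print(msg.content)            # final answer
    print(msg.reasoning_content)  # chain-of-thought captured by the wrapper
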
+    elif provider == "google":
+        return ChatGoogleGenerativeAI(
+            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
+            temperature=kwargs.get("temperature", 0.0),
+            api_key=api_key,
+        )
+    elif provider == "ollama":
+        base_url = kwargs.get("base_url") or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
+        if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
+            # Local R1 models emit <think> blocks, handled by the wrapper above.
+            return DeepSeekR1ChatOllama(
+                model=kwargs.get("model_name", "deepseek-r1:14b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                base_url=base_url,
+            )
+        else:
+            return ChatOllama(
+                model=kwargs.get("model_name", "qwen2.5:7b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                num_predict=kwargs.get("num_predict", 1024),
+                base_url=base_url,
+            )
+    elif provider == "azure_openai":
+        base_url = kwargs.get("base_url") or os.getenv("AZURE_OPENAI_ENDPOINT", "")
+        api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
+        return AzureChatOpenAI(
+            model=kwargs.get("model_name", "gpt-4o"),
+            temperature=kwargs.get("temperature", 0.0),
+            api_version=api_version,
+            azure_endpoint=base_url,
+            api_key=api_key,
+        )
+    elif provider == "alibaba":
+        base_url = kwargs.get("base_url") or os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "qwen-plus"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "ibm":
+        parameters = {
+            "temperature": kwargs.get("temperature", 0.0),
+            "max_tokens": kwargs.get("num_ctx", 32000),
+        }
+        base_url = kwargs.get("base_url") or os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
+        return ChatWatsonx(
+            model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
+            url=base_url,
+            project_id=os.getenv("IBM_PROJECT_ID"),
+            apikey=os.getenv("IBM_API_KEY"),
+            params=parameters,
+        )
+    elif provider == "moonshot":
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=os.getenv("MOONSHOT_ENDPOINT", "https://api.moonshot.cn/v1"),
+            api_key=api_key,
+        )
+    elif provider == "unbound":
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "gpt-4o-mini"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
+            api_key=api_key,
+        )
+    elif provider == "siliconflow":
+        base_url = kwargs.get("base_url") or os.getenv("SILICONFLOW_ENDPOINT", "")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "modelscope":
+        base_url = kwargs.get("base_url") or os.getenv("MODELSCOPE_ENDPOINT", "")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")