# --- examples/try.py ---
import os
import sys

# NOTE: this must run BEFORE any `mlx_use` import below — the example is run
# from a checkout, so the package root has to be on sys.path first. (The
# previous revision imported mlx_use.llm.* above this line, which fails
# whenever the package is not pip-installed.)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import asyncio

from mlx_use import Agent
from mlx_use.controller.service import Controller
from mlx_use.llm.factory import LLMFactory
from mlx_use.llm.providers import LLM, LLMConfig

# Pick whichever provider actually has an API key configured, in priority
# order, instead of hard-coding OpenAI. This restores the fallback behavior
# the refactor dropped. Model defaults come from LLMConfig.PROVIDERS.
_PROVIDER_BY_ENV = [
    ("GEMINI_API_KEY", LLM.GEMINI),
    ("OPENAI_API_KEY", LLM.OPENAI),
    ("ANTHROPIC_API_KEY", LLM.CLAUDE),
]

for _env_var, _provider in _PROVIDER_BY_ENV:
    if os.getenv(_env_var):
        config = LLMConfig(provider_name=_provider, env_var=_env_var)
        break
else:
    raise ValueError(
        "No API keys found. Please set at least one of GEMINI_API_KEY, "
        "OPENAI_API_KEY, or ANTHROPIC_API_KEY in your .env file"
    )

provider = LLMFactory.create_provider(config)
llm = provider.get_llm()

controller = Controller()


async def main():
    """Run a fixed greeting agent, then an agent for a user-supplied task."""
    agent_greeting = Agent(
        task='Say "Hi there $whoami, What can I do for you today?"',
        llm=llm,
        controller=controller,
        use_vision=False,
        max_actions_per_step=1,
        max_failures=5,
    )
    await agent_greeting.run(max_steps=25)

    # Blocking input() is acceptable here: nothing else is scheduled on the
    # loop while we wait for the user.
    task = input("Enter the task: ")
    agent_task = Agent(
        task=task,
        llm=llm,
        controller=controller,
        use_vision=False,
        max_actions_per_step=4,
        max_failures=5,
    )
    await agent_task.run(max_steps=25)


asyncio.run(main())


# --- mlx_use/llm/factory.py ---
from mlx_use.llm.providers import GenericLLMProvider, LLMConfig, LLMProvider


class LLMFactory:
    """Factory for creating LLM providers."""

    @staticmethod
    def create_provider(config: LLMConfig) -> LLMProvider:
        """Create an LLM provider wrapping the given configuration.

        Args:
            config: Provider class, API-key env var and model parameters.

        Returns:
            A GenericLLMProvider bound to *config*.
        """
        return GenericLLMProvider(config)
# --- mlx_use/llm/providers.py ---
import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr

logger = logging.getLogger(__name__)


class LLM(Enum):
    """Supported LLM back-ends."""

    DEEPSEEK = "deepseek"
    CLAUDE = "claude"
    OPENAI = "openai"
    GEMINI = "gemini"


class LLMConfig:
    """Configuration for an LLM provider.

    Maps each LLM enum member to its langchain chat-model class and default
    constructor parameters; user-supplied params override the defaults.
    """

    PROVIDERS = {
        LLM.DEEPSEEK: {"class": ChatOpenAI, "default_params": {"base_url": "https://api.deepseek.com/v1", "model": "deepseek-chat"}},
        LLM.CLAUDE: {"class": ChatAnthropic, "default_params": {"model_name": "claude-3-7-sonnet-20250219"}},
        LLM.OPENAI: {"class": ChatOpenAI, "default_params": {"model": "gpt-4o"}},
        LLM.GEMINI: {"class": ChatGoogleGenerativeAI, "default_params": {"model": "gemini-2.0-flash"}},
    }

    def __init__(self, provider_name: LLM, env_var: str, params: Optional[Dict[str, Any]] = None):
        """Build a config for *provider_name*.

        Args:
            provider_name: Which back-end to use.
            env_var: Name of the environment variable holding the API key.
            params: Extra constructor parameters; override the defaults.

        Raises:
            ValueError: If *provider_name* is not a known provider (the
                previous revision leaked a bare KeyError here).
        """
        if provider_name not in self.PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider_name}")
        self.provider_class = self.PROVIDERS[provider_name]["class"]
        self.env_var = env_var
        # User params win over defaults on key collisions.
        self.params = {**self.PROVIDERS[provider_name].get("default_params", {}), **(params or {})}

    def get_api_key(self) -> str:
        """Return the API key from the environment.

        Raises:
            ValueError: If the configured environment variable is unset/empty.
        """
        api_key = os.getenv(self.env_var)
        if not api_key:
            raise ValueError(f"{self.env_var} environment variable not set")
        return api_key


class LLMProvider(ABC):
    """Strategy interface for LLM providers."""

    # Signature matches GenericLLMProvider.get_llm: the concrete
    # implementation takes per-call parameter overrides, not a config —
    # the previous `config: Optional[LLMConfig]` declaration violated the
    # substitution principle.
    @abstractmethod
    def get_llm(self, params: Optional[Dict[str, Any]] = None) -> BaseChatModel:
        """Return the LLM instance, applying optional parameter overrides."""
        raise NotImplementedError


class GenericLLMProvider(LLMProvider):
    """Generic LLM provider implementation using dynamic configuration."""

    def __init__(self, config: Optional[LLMConfig] = None):
        # Kept Optional for backward compatibility; get_llm() validates it.
        self.config = config

    def get_llm(self, params: Optional[Dict[str, Any]] = None) -> BaseChatModel:
        """Instantiate the configured chat model.

        Args:
            params: Per-call overrides merged on top of the config's params.

        Raises:
            ValueError: If the provider was constructed without a config
                (previously surfaced as an opaque AttributeError) or the
                API key is missing.
        """
        if self.config is None:
            raise ValueError("GenericLLMProvider requires an LLMConfig")
        try:
            merged_params = {**self.config.params, **(params or {})}
            return self.config.provider_class(
                api_key=SecretStr(self.config.get_api_key()),
                **merged_params,
            )
        except Exception:
            # Log the readable class name (not the class repr) and re-raise.
            logger.exception("Error initializing %s LLM", self.config.provider_class.__name__)
            raise