Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 31 additions & 55 deletions examples/try.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,50 @@
import os
import sys

from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

from mlx_use.llm.factory import LLMFactory
from mlx_use.llm.providers import LLMConfig, LLM

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import asyncio

from mlx_use import Agent
from pydantic import SecretStr
from mlx_use.controller.service import Controller

# Default LLM configuration for this example: OpenAI's gpt-4o model,
# authenticated via the OPENAI_API_KEY environment variable
# (LLMConfig.get_api_key raises if it is unset).
config = LLMConfig(
provider_name=LLM.OPENAI,
env_var="OPENAI_API_KEY",
params={"model": "gpt-4o"}
)

def set_llm(llm_provider: Optional[str] = None):
    """Return a chat model for *llm_provider*, or None.

    Supported provider tags: "OAI" (OpenAI gpt-4), "google" (Gemini),
    "anthropic" (Claude).  Returns None when the tag is unknown or the
    matching API-key environment variable is unset/empty.

    Raises:
        ValueError: If llm_provider is falsy (None or empty string).
    """
    if not llm_provider:
        raise ValueError("No llm provider was set")

    # Read each key once (the original called os.getenv twice per branch,
    # so SecretStr could in principle receive a different/None value).
    if llm_provider == "OAI":
        api_key = os.getenv('OPENAI_API_KEY')
        if api_key:
            return ChatOpenAI(model='gpt-4', api_key=SecretStr(api_key))

    if llm_provider == "google":
        api_key = os.getenv('GEMINI_API_KEY')
        if api_key:
            return ChatGoogleGenerativeAI(model='gemini-2.0-flash-exp', api_key=SecretStr(api_key))

    if llm_provider == "anthropic":
        api_key = os.getenv('ANTHROPIC_API_KEY')
        if api_key:
            return ChatAnthropic(model='claude-3-sonnet-20240229', api_key=SecretStr(api_key))

    # Unknown tag or missing key: callers check for None.
    return None

# Build the LLM via the provider factory configured above.
# The legacy per-key fallback chain that used to live here was dead code:
# its result was unconditionally overwritten by provider.get_llm() below,
# and it raised for key sets the configured provider does not even need.
# LLMConfig.get_api_key() already raises a clear ValueError when the
# configured environment variable is not set.
provider = LLMFactory.create_provider(config)
llm = provider.get_llm()

# Controller shared by the agents created in main().
controller = Controller()

async def main():
    """Run a canned greeting agent, then an agent for a user-supplied task.

    NOTE(review): the original body contained this exact sequence twice —
    an apparent merge/diff artifact that would greet and prompt the user
    two times.  A single copy is kept here.
    """
    # Greeting agent: one action per step keeps the canned reply focused.
    agent_greeting = Agent(
        task='Say "Hi there $whoami, What can I do for you today?"',
        llm=llm,
        controller=controller,
        use_vision=False,
        max_actions_per_step=1,
        max_failures=5
    )
    await agent_greeting.run(max_steps=25)

    task = input("Enter the task: ")

    # Task agent: allow up to 4 actions per step for more complex work.
    agent_task = Agent(
        task=task,
        llm=llm,
        controller=controller,
        use_vision=False,
        max_actions_per_step=4,
        max_failures=5
    )
    await agent_task.run(max_steps=25)


if __name__ == '__main__':
    asyncio.run(main())
10 changes: 10 additions & 0 deletions mlx_use/llm/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from mlx_use.llm.providers import LLMProvider, GenericLLMProvider, LLMConfig


class LLMFactory:
    """Factory that builds LLM provider strategies from configurations."""

    @staticmethod
    def create_provider(config: LLMConfig) -> LLMProvider:
        """Build a provider wired to *config*.

        Every configuration is currently served by GenericLLMProvider,
        which instantiates the configured chat-model class on demand.
        """
        provider = GenericLLMProvider(config)
        return provider
66 changes: 66 additions & 0 deletions mlx_use/llm/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, Dict, Any

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr


class LLM(Enum):
    """Supported LLM backends; values are the lowercase provider names."""

    DEEPSEEK = "deepseek"  # served through the OpenAI-compatible DeepSeek API
    CLAUDE = "claude"
    OPENAI = "openai"
    GEMINI = "gemini"


class LLMConfig:
    """Configuration for an LLM provider.

    Resolves the chat-model class and default parameters for a provider,
    merges caller overrides on top, and records which environment
    variable holds the API key.
    """

    # Registry: provider -> chat-model class plus its default kwargs.
    PROVIDERS = {
        LLM.DEEPSEEK: {"class": ChatOpenAI, "default_params": {"base_url": "https://api.deepseek.com/v1", "model": "deepseek-chat"}},
        LLM.CLAUDE: {"class": ChatAnthropic, "default_params": {"model_name": "claude-3-7-sonnet-20250219"}},
        LLM.OPENAI: {"class": ChatOpenAI, "default_params": {"model": "gpt-4o"}},
        LLM.GEMINI: {"class": ChatGoogleGenerativeAI, "default_params": {"model": "gemini-2.0-flash"}},
    }

    def __init__(self, provider_name: LLM, env_var: str, params: Optional[Dict[str, Any]] = None):
        """Look up the provider spec and merge *params* over its defaults."""
        spec = self.PROVIDERS[provider_name]
        self.provider_class = spec["class"]
        self.env_var = env_var
        # Caller-supplied params win over the registry defaults.
        merged = dict(spec.get("default_params", {}))
        merged.update(params or {})
        self.params = merged

    def get_api_key(self) -> str:
        """Return the API key from the environment.

        Raises:
            ValueError: If the environment variable is unset or empty.
        """
        if not (api_key := os.getenv(self.env_var)):
            raise ValueError(f"{self.env_var} environment variable not set")
        return api_key


class LLMProvider(ABC):
    """Strategy interface: implementations construct chat-model instances."""

    @abstractmethod
    def get_llm(self, config: Optional[LLMConfig] = None) -> BaseChatModel:
        """Return a ready-to-use chat model instance."""


class GenericLLMProvider(LLMProvider):
    """Generic LLM provider implementation using dynamic configuration.

    Wraps an LLMConfig and instantiates its provider class with the API
    key pulled from the environment at call time.
    """

    def __init__(self, config: Optional[LLMConfig] = None):
        # config may be None at construction time, but get_llm() requires it.
        self.config = config

    # NOTE(review): parameter name differs from the abstract
    # LLMProvider.get_llm(config=...) signature — confirm intended.
    def get_llm(self, params: Optional[Dict[str, Any]] = None) -> BaseChatModel:
        """Instantiate and return the configured chat model.

        Args:
            params: Optional per-call overrides merged over the config's
                default parameters (overrides win).

        Raises:
            ValueError: If no config was supplied, or the configured
                API-key environment variable is unset (via get_api_key).
        """
        if self.config is None:
            # Fail fast with a clear message instead of the cryptic
            # AttributeError the original raised on self.config.params.
            raise ValueError("GenericLLMProvider requires an LLMConfig")
        try:
            merged_params = {**self.config.params, **(params or {})}
            return self.config.provider_class(
                api_key=SecretStr(self.config.get_api_key()),
                **merged_params
            )
        except Exception as e:
            print(f"Error initializing {self.config.provider_class} LLM: {e}")
            raise