Commit f07e92d

feat: add Anthropic Claude API support
- Add AnthropicLLM class for Claude API integration
- Update LLMEnsemble to support both OpenAI and Anthropic models
- Add automatic model detection and API base configuration
- Add comprehensive test coverage for new functionality
- Update dependencies to include anthropic package
1 parent: babe3f9

6 files changed: +273 -3 lines changed

openevolve/config.py

Lines changed: 18 additions & 0 deletions
@@ -36,6 +36,24 @@ class LLMConfig:
     retries: int = 3
     retry_delay: int = 5
 
+    def __post_init__(self):
+        """Set up API key from environment if not provided"""
+        if not self.api_key:
+            # Try to get API key from environment
+            if self.primary_model.startswith("claude-") or self.primary_model.startswith(
+                "anthropic/"
+            ):
+                self.api_key = os.environ.get("ANTHROPIC_API_KEY")
+            else:
+                self.api_key = os.environ.get("OPENAI_API_KEY")
+
+        # Set default API base based on model type
+        if self.api_base == "https://api.openai.com/v1":
+            if self.primary_model.startswith("claude-") or self.primary_model.startswith(
+                "anthropic/"
+            ):
+                self.api_base = "https://api.anthropic.com/v1"
+
 
 @dataclass
 class PromptConfig:
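
The new __post_init__ hook keys everything off the primary model's name prefix: claude-/anthropic/ models pull ANTHROPIC_API_KEY and switch the default API base, everything else stays on the OpenAI defaults. A minimal sketch of that behaviour (hypothetical key value; assumes the other LLMConfig fields keep their defaults):

# Hypothetical sketch of the model-prefix detection added above; not part of the commit.
import os

from openevolve.config import LLMConfig

os.environ["ANTHROPIC_API_KEY"] = "test-key"  # placeholder value for illustration

config = LLMConfig(primary_model="claude-3-sonnet")
print(config.api_base)  # https://api.anthropic.com/v1 (switched from the OpenAI default)
print(config.api_key)   # test-key (read from ANTHROPIC_API_KEY)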

openevolve/llm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,5 +5,6 @@
 from openevolve.llm.base import LLMInterface
 from openevolve.llm.ensemble import LLMEnsemble
 from openevolve.llm.openai import OpenAILLM
+from openevolve.llm.anthropic import AnthropicLLM
 
-__all__ = ["LLMInterface", "OpenAILLM", "LLMEnsemble"]
+__all__ = ["LLMInterface", "OpenAILLM", "AnthropicLLM", "LLMEnsemble"]
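
With the package export in place, the new class can be imported directly alongside the existing ones, e.g. from openevolve.llm import AnthropicLLM, LLMEnsemble.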

openevolve/llm/anthropic.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+"""
+Anthropic Claude API interface for LLMs
+"""
+
+import asyncio
+import logging
+from typing import Any, Dict, List, Optional
+
+import anthropic
+
+from openevolve.config import LLMConfig
+from openevolve.llm.base import LLMInterface
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicLLM(LLMInterface):
+    """LLM interface using Anthropic's Claude API"""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        model: Optional[str] = None,
+    ):
+        self.config = config
+        self.model = model or config.primary_model
+
+        # Set up API client
+        self.client = anthropic.Anthropic(
+            api_key=config.api_key,
+            base_url=config.api_base,
+        )
+
+        logger.info(f"Initialized Anthropic LLM with model: {self.model}")
+
+    async def generate(self, prompt: str, **kwargs) -> str:
+        """Generate text from a prompt"""
+        return await self.generate_with_context(
+            system_message=self.config.system_message,
+            messages=[{"role": "user", "content": prompt}],
+            **kwargs,
+        )
+
+    async def generate_with_context(
+        self, system_message: str, messages: List[Dict[str, str]], **kwargs
+    ) -> str:
+        """Generate text using a system message and conversational context"""
+        # Prepare messages for Claude format
+        formatted_messages = []
+        for msg in messages:
+            formatted_messages.append({"role": msg["role"], "content": msg["content"]})
+
+        # Set up generation parameters
+        params = {
+            "model": self.model,
+            "system": system_message,
+            "messages": formatted_messages,
+            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+            "temperature": kwargs.get("temperature", self.config.temperature),
+            "top_p": kwargs.get("top_p", self.config.top_p),
+        }
+
+        # Attempt the API call with retries
+        retries = kwargs.get("retries", self.config.retries)
+        retry_delay = kwargs.get("retry_delay", self.config.retry_delay)
+        timeout = kwargs.get("timeout", self.config.timeout)
+
+        for attempt in range(retries + 1):
+            try:
+                response = await asyncio.wait_for(self._call_api(params), timeout=timeout)
+                return response
+            except asyncio.TimeoutError:
+                if attempt < retries:
+                    logger.warning(f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error(f"All {retries + 1} attempts failed with timeout")
+                    raise
+            except Exception as e:
+                if attempt < retries:
+                    logger.warning(
+                        f"Error on attempt {attempt + 1}/{retries + 1}: {str(e)}. Retrying..."
+                    )
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error(f"All {retries + 1} attempts failed with error: {str(e)}")
+                    raise
+
+    async def _call_api(self, params: Dict[str, Any]) -> str:
+        """Make the actual API call"""
+        # Use asyncio to run the blocking API call in a thread pool
+        loop = asyncio.get_event_loop()
+        response = await loop.run_in_executor(None, lambda: self.client.messages.create(**params))
+
+        # Extract the response content
+        return response.content[0].text
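
For orientation, a minimal usage sketch of the new class (hypothetical model name and prompt; assumes LLMConfig defines the system_message, max_tokens, temperature, top_p, timeout, retries, and retry_delay fields this module reads):

# Hypothetical usage sketch; not part of the commit.
import asyncio

from openevolve.config import LLMConfig
from openevolve.llm.anthropic import AnthropicLLM


async def main():
    # With no api_key given, LLMConfig.__post_init__ falls back to ANTHROPIC_API_KEY.
    config = LLMConfig(primary_model="claude-3-sonnet")
    llm = AnthropicLLM(config)
    # generate() wraps generate_with_context(), which runs the blocking Anthropic
    # client call in a thread pool with the configured timeout/retry handling.
    print(await llm.generate("Say hello in one short sentence."))


asyncio.run(main())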

openevolve/llm/ensemble.py

Lines changed: 11 additions & 2 deletions
@@ -10,19 +10,28 @@
 from openevolve.config import LLMConfig
 from openevolve.llm.base import LLMInterface
 from openevolve.llm.openai import OpenAILLM
+from openevolve.llm.anthropic import AnthropicLLM
 
 logger = logging.getLogger(__name__)
 
 
+def create_llm(config: LLMConfig, model: str) -> LLMInterface:
+    """Create an LLM instance based on the model name"""
+    if model.startswith("claude-") or model.startswith("anthropic/"):
+        return AnthropicLLM(config, model=model)
+    else:
+        return OpenAILLM(config, model=model)
+
+
 class LLMEnsemble:
     """Ensemble of LLMs for generating diverse code modifications"""
 
     def __init__(self, config: LLMConfig):
         self.config = config
 
         # Initialize primary and secondary models
-        self.primary_model = OpenAILLM(config, model=config.primary_model)
-        self.secondary_model = OpenAILLM(config, model=config.secondary_model)
+        self.primary_model = create_llm(config, config.primary_model)
+        self.secondary_model = create_llm(config, config.secondary_model)
 
         # Model weights for sampling
         self._weights = [
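
create_llm is the single dispatch point: model names starting with claude- or anthropic/ get an AnthropicLLM, everything else falls back to OpenAILLM, so the ensemble itself needs no further changes. A small illustration (assumes LLMConfig accepts secondary_model as a constructor field, as the ensemble code implies, with an explicit api_key so no environment variables are needed):

# Hypothetical illustration of the dispatch rule; not part of the commit.
from openevolve.config import LLMConfig
from openevolve.llm.anthropic import AnthropicLLM
from openevolve.llm.ensemble import create_llm
from openevolve.llm.openai import OpenAILLM

config = LLMConfig(primary_model="claude-3-sonnet", secondary_model="gpt-4", api_key="test-key")

# Claude-prefixed names route to the Anthropic client, the rest to OpenAI.
assert isinstance(create_llm(config, config.primary_model), AnthropicLLM)
assert isinstance(create_llm(config, config.secondary_model), OpenAILLM)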

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ authors = [
 ]
 dependencies = [
     "openai>=1.0.0",
+    "anthropic>=0.8.0",
     "pyyaml>=6.0",
     "numpy>=1.22.0",
     "tqdm>=4.64.0",

tests/test_llm.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+"""
+Tests for LLM implementations
+"""
+
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from openevolve.config import LLMConfig
+from openevolve.llm.anthropic import AnthropicLLM
+from openevolve.llm.openai import OpenAILLM
+
+
+class TestLLMImplementations(unittest.IsolatedAsyncioTestCase):
+    """Tests for LLM implementations"""
+
+    def setUp(self):
+        """Set up test configuration"""
+        self.config = LLMConfig(
+            primary_model="test-model",
+            api_key="test-key",
+            api_base="https://test.api",
+        )
+
+    @patch("anthropic.Anthropic")
+    async def test_anthropic_llm_generate(self, mock_anthropic):
+        """Test Anthropic LLM generate method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="Test response")]
+        mock_anthropic.return_value.messages.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = AnthropicLLM(self.config)
+
+        # Test generate
+        response = await llm.generate("Test prompt")
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_anthropic.return_value.messages.create.assert_called_once()
+        call_args = mock_anthropic.return_value.messages.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["messages"][0]["role"], "user")
+        self.assertEqual(call_args["messages"][0]["content"], "Test prompt")
+
+    @patch("anthropic.Anthropic")
+    async def test_anthropic_llm_generate_with_context(self, mock_anthropic):
+        """Test Anthropic LLM generate_with_context method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="Test response")]
+        mock_anthropic.return_value.messages.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = AnthropicLLM(self.config)
+
+        # Test generate_with_context
+        messages = [
+            {"role": "user", "content": "Test message 1"},
+            {"role": "assistant", "content": "Test response 1"},
+            {"role": "user", "content": "Test message 2"},
+        ]
+        response = await llm.generate_with_context("Test system", messages)
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_anthropic.return_value.messages.create.assert_called_once()
+        call_args = mock_anthropic.return_value.messages.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["system"], "Test system")
+        self.assertEqual(len(call_args["messages"]), 3)
+        self.assertEqual(call_args["messages"][0]["role"], "user")
+        self.assertEqual(call_args["messages"][0]["content"], "Test message 1")
+
+    @patch("openai.OpenAI")
+    async def test_openai_llm_generate(self, mock_openai):
+        """Test OpenAI LLM generate method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
+        mock_openai.return_value.chat.completions.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = OpenAILLM(self.config)
+
+        # Test generate
+        response = await llm.generate("Test prompt")
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_openai.return_value.chat.completions.create.assert_called_once()
+        call_args = mock_openai.return_value.chat.completions.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["messages"][0]["role"], "user")
+        self.assertEqual(call_args["messages"][0]["content"], "Test prompt")
+
+    @patch("openai.OpenAI")
+    async def test_openai_llm_generate_with_context(self, mock_openai):
+        """Test OpenAI LLM generate_with_context method"""
+        # Set up mock response
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
+        mock_openai.return_value.chat.completions.create.return_value = mock_response
+
+        # Create LLM instance
+        llm = OpenAILLM(self.config)
+
+        # Test generate_with_context
+        messages = [
+            {"role": "user", "content": "Test message 1"},
+            {"role": "assistant", "content": "Test response 1"},
+            {"role": "user", "content": "Test message 2"},
+        ]
+        response = await llm.generate_with_context("Test system", messages)
+        self.assertEqual(response, "Test response")
+
+        # Verify API call
+        mock_openai.return_value.chat.completions.create.assert_called_once()
+        call_args = mock_openai.return_value.chat.completions.create.call_args[1]
+        self.assertEqual(call_args["model"], "test-model")
+        self.assertEqual(call_args["messages"][0]["role"], "system")
+        self.assertEqual(call_args["messages"][0]["content"], "Test system")
+        self.assertEqual(len(call_args["messages"]), 4)  # system + 3 messages
+
+    def test_llm_config_model_detection(self):
+        """Test LLM configuration model type detection"""
+        # Test OpenAI model
+        config = LLMConfig(primary_model="gpt-4")
+        self.assertEqual(config.api_base, "https://api.openai.com/v1")
+
+        # Test Claude model
+        config = LLMConfig(primary_model="claude-3-sonnet")
+        self.assertEqual(config.api_base, "https://api.anthropic.com/v1")
+
+        # Test custom API base
+        config = LLMConfig(
+            primary_model="claude-3-sonnet",
+            api_base="https://custom.api",
+        )
+        self.assertEqual(config.api_base, "https://custom.api")
+
+
+if __name__ == "__main__":
+    unittest.main()
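
Because the module ends with unittest.main(), the suite can be run directly (python tests/test_llm.py) or through any unittest-compatible runner, assuming openevolve and the anthropic/openai packages are importable. The async test methods rely on unittest.IsolatedAsyncioTestCase (Python 3.8+) so that the runner actually awaits each coroutine instead of silently skipping its assertions.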
