
Commit bd44e44

Updated the temperature parameter handling for O-series models
This pull request fixes an issue in the OpenAILLM class: O-series models (e.g., o3-mini, o4-mini) should not receive a temperature parameter, but temperature was previously always included in the API call parameters, with no warning or error. This commit corrects that behavior and adds a unit test, tests/test_openai_model.py.
1 parent e152825 commit bd44e44
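
For illustration, here is a minimal, standalone sketch of the parameter selection this commit introduces. The helper name build_params and the example values are hypothetical; the real logic lives in OpenAILLM.generate_with_context (see the diff below).

from typing import Any, Dict, List

# Mirrors the new behavior: O-series models on the official endpoint get
# max_completion_tokens and no sampling parameters; everything else keeps
# temperature/top_p and max_tokens.
O_SERIES = {"o1", "o1-mini", "o1-pro", "o3", "o3-mini", "o3-pro", "o4-mini"}
OPENAI_BASE = "https://api.openai.com/v1"

def build_params(model: str, api_base: str, messages: List[Dict[str, str]],
                 temperature: float = 0.7, top_p: float = 0.95,
                 max_tokens: int = 4096) -> Dict[str, Any]:
    params: Dict[str, Any] = {"model": model, "messages": messages}
    if api_base == OPENAI_BASE:
        params["max_completion_tokens"] = max_tokens
    else:
        params["max_tokens"] = max_tokens
    if not (api_base == OPENAI_BASE and model.lower() in O_SERIES):
        params["temperature"] = temperature
        params["top_p"] = top_p
    return params

msgs = [{"role": "user", "content": "hi"}]
print(sorted(build_params("o4-mini", OPENAI_BASE, msgs)))
# ['max_completion_tokens', 'messages', 'model']
print(sorted(build_params("gpt-4o", "https://my.custom.endpoint", msgs)))
# ['max_tokens', 'messages', 'model', 'temperature', 'top_p']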

File tree: 2 files changed, +126 / -17 lines

openevolve/llm/openai.py

Lines changed: 51 additions & 17 deletions
@@ -14,6 +14,10 @@
 
 logger = logging.getLogger(__name__)
 
+_O_SERIES_MODELS = {"o1", "o1-mini", "o1-pro",
+                    "o3", "o3-mini", "o3-pro",
+                    "o4-mini"}
+
 
 class OpenAILLM(LLMInterface):
     """LLM interface using OpenAI-compatible APIs"""
@@ -64,22 +68,49 @@ async def generate_with_context(
         formatted_messages = [{"role": "system", "content": system_message}]
         formatted_messages.extend(messages)
 
+        kwargs.setdefault("temperature", self.temperature)
+        # Base params shared by all models
+        params: Dict[str, Any] = {
+            "model": self.model,
+            "messages": formatted_messages,
+        }
+
         # Set up generation parameters
-        if self.api_base == "https://api.openai.com/v1" and str(self.model).lower().startswith("o"):
-            # For o-series models
-            params = {
-                "model": self.model,
-                "messages": formatted_messages,
-                "max_completion_tokens": kwargs.get("max_tokens", self.max_tokens),
-            }
+        # if self.api_base == "https://api.openai.com/v1" and str(self.model).lower().startswith("o"):
+        #     # For o-series models
+        #     params = {
+        #         "model": self.model,
+        #         "messages": formatted_messages,
+        #         "max_completion_tokens": kwargs.get("max_tokens", self.max_tokens),
+        #     }
+        # else:
+        #     params = {
+        #         "model": self.model,
+        #         "messages": formatted_messages,
+        #         "temperature": kwargs.get("temperature", self.temperature),
+        #         "top_p": kwargs.get("top_p", self.top_p),
+        #         "max_tokens": kwargs.get("max_tokens", self.max_tokens),
+        #     }
+
+        if self.api_base == "https://api.openai.com/v1":
+            params["max_completion_tokens"] = kwargs.get(
+                "max_tokens", self.max_tokens)
         else:
-            params = {
-                "model": self.model,
-                "messages": formatted_messages,
-                "temperature": kwargs.get("temperature", self.temperature),
-                "top_p": kwargs.get("top_p", self.top_p),
-                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
-            }
+            params["max_tokens"] = kwargs.get("max_tokens", self.max_tokens)
+
+        get_model = str(self.model).lower()
+        if self.api_base == "https://api.openai.com/v1" and get_model in _O_SERIES_MODELS:
+            # Warn if a temperature was configured for an O-series model
+            if "temperature" in kwargs:
+                logger.warning(
+                    f"Model {self.model!r} doesn't support temperature"
+                )
+
+        else:
+            params["temperature"] = kwargs.get("temperature", self.temperature)
+            params["top_p"] = kwargs.get("top_p", self.top_p)
+
+        print("[DEBUG] LLM params:", params.keys())
 
         # Add seed parameter for reproducibility if configured
         # Skip seed parameter for Google AI Studio endpoint as it doesn't support it
@@ -104,10 +135,12 @@ async def generate_with_context(
                 return response
             except asyncio.TimeoutError:
                 if attempt < retries:
-                    logger.warning(f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
+                    logger.warning(
+                        f"Timeout on attempt {attempt + 1}/{retries + 1}. Retrying...")
                     await asyncio.sleep(retry_delay)
                 else:
-                    logger.error(f"All {retries + 1} attempts failed with timeout")
+                    logger.error(
+                        f"All {retries + 1} attempts failed with timeout")
                     raise
             except Exception as e:
                 if attempt < retries:
@@ -116,7 +149,8 @@ async def generate_with_context(
                     )
                     await asyncio.sleep(retry_delay)
                 else:
-                    logger.error(f"All {retries + 1} attempts failed with error: {str(e)}")
+                    logger.error(
+                        f"All {retries + 1} attempts failed with error: {str(e)}")
                     raise
 
     async def _call_api(self, params: Dict[str, Any]) -> str:
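
For context, a minimal usage sketch of how a caller would exercise the new branch. The config object mirrors the test fixture added below; the exact config class openevolve uses may differ, and the model name and API key are placeholders.

import asyncio
from types import SimpleNamespace

from openevolve.llm.openai import OpenAILLM

# Config shape copied from the new test fixture; field names are assumed to
# match what OpenAILLM reads. With an O-series model on the official endpoint,
# the request carries max_completion_tokens and drops temperature/top_p.
cfg = SimpleNamespace(
    name="o3-mini",
    system_message="You are a helpful assistant.",
    temperature=0.7,        # not sent for O-series; a warning is logged instead
    top_p=0.95,
    max_tokens=512,
    timeout=60,
    retries=2,
    retry_delay=1,
    api_base="https://api.openai.com/v1",
    api_key="sk-...",       # placeholder
    random_seed=42,
)

llm = OpenAILLM(cfg)
print(asyncio.run(llm.generate("Say hello")))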

tests/test_openai_model.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+
+"""
+Tests for O series model config check
+"""
+import asyncio
+import unittest
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+from openevolve.llm.openai import OpenAILLM
+
+
+class TestOpenAILLM(unittest.TestCase):
+
+    def setUp(self):
+        self.model_cfg = SimpleNamespace(
+            name="test-model",
+            system_message="SYS",
+            temperature=0.7,
+            top_p=0.98,
+            max_tokens=42,
+            timeout=1,
+            retries=0,
+            retry_delay=0,
+            api_base="https://api.openai.com/v1",
+            api_key="fake-key",
+            random_seed=123,
+        )
+
+        fake_choice = SimpleNamespace(message=SimpleNamespace(content=" OK"))
+        fake_response = SimpleNamespace(choices=[fake_choice])
+
+        self.fake_client = MagicMock()
+        self.fake_client.chat.completions.create.return_value = fake_response
+
+    def test_generate_happy_path(self):
+
+        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client) as _:
+            llm = OpenAILLM(self.model_cfg)
+
+            result = asyncio.get_event_loop().run_until_complete(
+                llm.generate("hello world")
+            )
+
+        self.assertEqual(result, " OK")
+
+        called_kwargs = self.fake_client.chat.completions.create.call_args.kwargs
+        msgs = called_kwargs["messages"]
+        self.assertEqual(msgs[0]["role"], "system")
+        self.assertEqual(msgs[0]["content"], "SYS")
+        self.assertEqual(msgs[1]["role"], "user")
+        self.assertEqual(msgs[1]["content"], "hello world")
+
+    def test_max_completion_tokens_branch(self):
+        self.model_cfg.name = "o4-mini"
+        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client):
+            llm = OpenAILLM(self.model_cfg)
+            asyncio.get_event_loop().run_until_complete(llm.generate("foo"))
+
+        called = self.fake_client.chat.completions.create.call_args.kwargs
+
+        self.assertIn("max_completion_tokens", called)
+        self.assertNotIn("max_tokens", called)
+
+    def test_fallback_max_tokens_branch(self):
+
+        self.model_cfg.api_base = "https://my.custom.endpoint"
+        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client):
+            llm = OpenAILLM(self.model_cfg)
+            asyncio.get_event_loop().run_until_complete(llm.generate("bar"))
+
+        called = self.fake_client.chat.completions.create.call_args.kwargs
+
+        self.assertIn("max_tokens", called)
+        self.assertNotIn("max_completion_tokens", called)
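
The new tests cover the max_completion_tokens/max_tokens branches. A sketch of how the temperature-warning path could be exercised with the same fixtures (hypothetical test, not part of this commit):

    # Hypothetical addition to TestOpenAILLM above; not part of this commit.
    def test_o_series_temperature_warning(self):
        self.model_cfg.name = "o3-mini"
        with patch("openevolve.llm.openai.openai.OpenAI", return_value=self.fake_client):
            llm = OpenAILLM(self.model_cfg)
            # The O-series guard should log a warning instead of sending temperature.
            with self.assertLogs("openevolve.llm.openai", level="WARNING"):
                asyncio.get_event_loop().run_until_complete(llm.generate("baz"))

        called = self.fake_client.chat.completions.create.call_args.kwargs
        self.assertNotIn("temperature", called)
        self.assertNotIn("top_p", called)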
