Commit 3fc9465

Merge pull request #47 from jvm123/feat-n-model-ensemble
Feature: Better support for LLM feedback and handling of LLM ensembles.
2 parents: 166f77f + 659e128 · commit: 3fc9465

9 files changed: +301 −124 lines

configs/default_config.yaml

Lines changed: 16 additions & 7 deletions

```diff
@@ -16,13 +16,21 @@ max_code_length: 10000  # Maximum allowed code length in character
 
 # LLM configuration
 llm:
-  # Primary model (used most frequently)
-  primary_model: "gemini-2.0-flash-lite"
-  primary_model_weight: 0.8    # Sampling weight for primary model
-
-  # Secondary model (used for occasional high-quality generations)
-  secondary_model: "gemini-2.0-flash"
-  secondary_model_weight: 0.2  # Sampling weight for secondary model
+  # Models for evolution
+  models:
+    # List of available models with their weights
+    - name: "gemini-2.0-flash-lite"
+      weight: 0.8
+    - name: "gemini-2.0-flash"
+      weight: 0.2
+
+  # Models for LLM feedback
+  evaluator_models:
+    # List of available models with their weights
+    - name: "gemini-2.0-flash-lite"
+      weight: 0.8
+    - name: "gemini-2.0-flash"
+      weight: 0.2
 
   # API configuration
   api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"  # Base URL for API (change for non-OpenAI models)
@@ -42,6 +50,7 @@ llm:
 prompt:
   template_dir: null  # Custom directory for prompt templates
   system_message: "You are an expert coder helping to improve programs through evolution."
+  evaluator_system_message: "You are an expert code reviewer."
 
   # Number of examples to include in the prompt
   num_top_programs: 3  # Number of top-performing programs to include
```
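
For reference, a minimal sketch of how a user config could use the new list format, for example to give the LLM-feedback ensemble a different model mix than the evolution ensemble (the model names and weights below are illustrative, not part of this commit):

```yaml
llm:
  # Evolution ensemble: models are sampled according to their weights
  models:
    - name: "gemini-2.0-flash-lite"
      weight: 0.7
    - name: "gemini-2.0-flash"
      weight: 0.3

  # Feedback ensemble; if this section is omitted, it falls back to `models`
  evaluator_models:
    - name: "gemini-2.0-flash"
      weight: 1.0
```

Each entry is parsed into an `LLMModelConfig` (see `openevolve/config.py` below), so an entry can also override shared settings such as `temperature` for that model only.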

openevolve/config.py

Lines changed: 105 additions & 23 deletions

```diff
@@ -11,22 +11,41 @@
 
 
 @dataclass
-class LLMConfig:
-    """Configuration for LLM models"""
+class LLMModelConfig:
+    """Configuration for a single LLM model"""
 
-    # Primary model
-    primary_model: str = "gemini-2.0-flash-lite"
-    primary_model_weight: float = 0.8
+    # API configuration
+    api_base: str = None
+    api_key: Optional[str] = None
+    name: str = None
 
-    # Secondary model
-    secondary_model: str = "gemini-2.0-flash"
-    secondary_model_weight: float = 0.2
+    # Weight for model in ensemble
+    weight: float = 1.0
+
+    # Generation parameters
+    system_message: Optional[str] = None
+    temperature: float = None
+    top_p: float = None
+    max_tokens: int = None
+
+    # Request parameters
+    timeout: int = None
+    retries: int = None
+    retry_delay: int = None
+
+
+@dataclass
+class LLMConfig(LLMModelConfig):
+    """Configuration for LLM models"""
 
     # API configuration
     api_base: str = "https://api.openai.com/v1"
-    api_key: Optional[str] = None
+    name: str = "gpt-4o"
 
     # Generation parameters
+    system_message: Optional[str] = (
+        "You are an expert coder helping to improve programs through evolution."
+    )
     temperature: float = 0.7
     top_p: float = 0.95
     max_tokens: int = 4096
@@ -36,13 +55,69 @@ class LLMConfig:
     retries: int = 3
     retry_delay: int = 5
 
+    # n-model configuration for evolution LLM ensemble
+    models: List[LLMModelConfig] = field(default_factory=lambda: [LLMModelConfig()])
+
+    # n-model configuration for evaluator LLM ensemble
+    evaluator_models: List[LLMModelConfig] = field(default_factory=lambda: [])
+
+    # Backwards compatibility with primary_model(_weight) options
+    primary_model: str = "gemini-2.0-flash-lite"
+    primary_model_weight: float = 0.8
+    secondary_model: str = "gemini-2.0-flash"
+    secondary_model_weight: float = 0.2
+
+    def __post_init__(self):
+        """Post-initialization to set up model configurations"""
+        # Handle backward compatibility for primary_model(_weight) and secondary_model(_weight).
+        if (self.primary_model or self.primary_model_weight) and len(self.models) < 1:
+            # Ensure we have a primary model
+            self.models.append(LLMModelConfig())
+        if self.primary_model:
+            self.models[0].name = self.primary_model
+        if self.primary_model_weight:
+            self.models[0].weight = self.primary_model_weight
+
+        if (self.secondary_model or self.secondary_model_weight) and len(self.models) < 2:
+            # Ensure we have a second model
+            self.models.append(LLMModelConfig())
+        if self.secondary_model:
+            self.models[1].name = self.secondary_model
+        if self.secondary_model_weight:
+            self.models[1].weight = self.secondary_model_weight
+
+        # If no evaluator models are defined, use the same models as for evolution
+        if not self.evaluator_models or len(self.evaluator_models) < 1:
+            self.evaluator_models = self.models.copy()
+
+        # Update models with shared configuration values
+        shared_config = {
+            "api_base": self.api_base,
+            "api_key": self.api_key,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "max_tokens": self.max_tokens,
+            "timeout": self.timeout,
+            "retries": self.retries,
+            "retry_delay": self.retry_delay,
+        }
+        self.update_model_params(shared_config)
+
+    def update_model_params(self, args: Dict[str, Any], overwrite: bool = False) -> None:
+        """Update model parameters for all models"""
+        for model in self.models + self.evaluator_models:
+            for key, value in args.items():
+                if overwrite or getattr(model, key, None) is None:
+                    setattr(model, key, value)
+
 
 @dataclass
 class PromptConfig:
     """Configuration for prompt generation"""
 
     template_dir: Optional[str] = None
     system_message: str = "You are an expert coder helping to improve programs through evolution."
+    evaluator_system_message: str = """You are an expert code reviewer."""
 
     # Number of examples to include in the prompt
     num_top_programs: int = 3
@@ -155,7 +230,14 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "Config":
 
         # Update nested configs
         if "llm" in config_dict:
-            config.llm = LLMConfig(**config_dict["llm"])
+            llm_dict = config_dict["llm"]
+            if "models" in llm_dict:
+                llm_dict["models"] = [LLMModelConfig(**m) for m in llm_dict["models"]]
+            if "evaluator_models" in llm_dict:
+                llm_dict["evaluator_models"] = [
+                    LLMModelConfig(**m) for m in llm_dict["evaluator_models"]
+                ]
+            config.llm = LLMConfig(**llm_dict)
         if "prompt" in config_dict:
            config.prompt = PromptConfig(**config_dict["prompt"])
         if "database" in config_dict:
@@ -176,10 +258,8 @@ def to_dict(self) -> Dict[str, Any]:
             "random_seed": self.random_seed,
             # Component configurations
             "llm": {
-                "primary_model": self.llm.primary_model,
-                "primary_model_weight": self.llm.primary_model_weight,
-                "secondary_model": self.llm.secondary_model,
-                "secondary_model_weight": self.llm.secondary_model_weight,
+                "models": self.llm.models,
+                "evaluator_models": self.llm.evaluator_models,
                 "api_base": self.llm.api_base,
                 "temperature": self.llm.temperature,
                 "top_p": self.llm.top_p,
@@ -191,6 +271,7 @@ def to_dict(self) -> Dict[str, Any]:
             "prompt": {
                 "template_dir": self.prompt.template_dir,
                 "system_message": self.prompt.system_message,
+                "evaluator_system_message": self.prompt.evaluator_system_message,
                 "num_top_programs": self.prompt.num_top_programs,
                 "num_diverse_programs": self.prompt.num_diverse_programs,
                 "use_template_stochasticity": self.prompt.use_template_stochasticity,
@@ -245,16 +326,17 @@ def to_yaml(self, path: Union[str, Path]) -> None:
 def load_config(config_path: Optional[Union[str, Path]] = None) -> Config:
     """Load configuration from a YAML file or use defaults"""
     if config_path and os.path.exists(config_path):
-        return Config.from_yaml(config_path)
+        config = Config.from_yaml(config_path)
+    else:
+        config = Config()
+
+    # Use environment variables if available
+    api_key = os.environ.get("OPENAI_API_KEY")
+    api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
 
-    # Use environment variables if available
-    api_key = os.environ.get("OPENAI_API_KEY")
-    api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
+    config.llm.update_model_params({"api_key": api_key, "api_base": api_base})
 
-    config = Config()
-    if api_key:
-        config.llm.api_key = api_key
-    if api_base:
-        config.llm.api_base = api_base
+    # Make the system message available to the individual models, in case it is not provided from the prompt sampler
+    config.llm.update_model_params({"system_message": config.prompt.system_message})
 
     return config
```
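
A short usage sketch of the backward-compatibility path above (illustrative only; it assumes the `openevolve` package at this commit is importable, and the comments show the expected output):

```python
from openevolve.config import LLMConfig

# Legacy-style construction: only the old primary/secondary fields are given,
# no explicit `models` list.
cfg = LLMConfig(
    primary_model="gemini-2.0-flash-lite",
    primary_model_weight=0.8,
    secondary_model="gemini-2.0-flash",
    secondary_model_weight=0.2,
)

# __post_init__ translates the legacy fields into the ensemble list ...
print([(m.name, m.weight) for m in cfg.models])
# [('gemini-2.0-flash-lite', 0.8), ('gemini-2.0-flash', 0.2)]

# ... and, since no evaluator_models were given, the feedback ensemble
# reuses the same models.
print([(m.name, m.weight) for m in cfg.evaluator_models])
# [('gemini-2.0-flash-lite', 0.8), ('gemini-2.0-flash', 0.2)]

# Shared settings (api_base, temperature, ...) are pushed into each model
# only where the model has not set its own value; overwrite=True forces them.
cfg.update_model_params({"temperature": 0.9}, overwrite=True)
print({m.name: m.temperature for m in cfg.models})
# {'gemini-2.0-flash-lite': 0.9, 'gemini-2.0-flash': 0.9}
```

The same resolution runs when a YAML file without a `models` section is loaded through `load_config`, so existing configs keep working.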

openevolve/controller.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -96,15 +96,25 @@ def __init__(
             self.file_extension = f".{self.file_extension}"
 
         # Initialize components
-        self.llm_ensemble = LLMEnsemble(self.config.llm)
+        self.llm_ensemble = LLMEnsemble(self.config.llm.models)
+        self.llm_evaluator_ensemble = LLMEnsemble(self.config.llm.evaluator_models)
+
         self.prompt_sampler = PromptSampler(self.config.prompt)
+        self.evaluator_prompt_sampler = PromptSampler(self.config.prompt)
+        self.evaluator_prompt_sampler.set_templates("evaluator_system_message")
 
         # Pass random seed to database if specified
         if self.config.random_seed is not None:
             self.config.database.random_seed = self.config.random_seed
 
         self.database = ProgramDatabase(self.config.database)
-        self.evaluator = Evaluator(self.config.evaluator, evaluation_file, self.llm_ensemble)
+
+        self.evaluator = Evaluator(
+            self.config.evaluator,
+            evaluation_file,
+            self.llm_evaluator_ensemble,
+            self.evaluator_prompt_sampler,
+        )
 
         logger.info(f"Initialized OpenEvolve with {initial_program_path} " f"and {evaluation_file}")
```

openevolve/evaluator.py

Lines changed: 48 additions & 45 deletions

````diff
@@ -14,10 +14,12 @@
 import uuid
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import traceback
 
 from openevolve.config import EvaluatorConfig
 from openevolve.llm.ensemble import LLMEnsemble
 from openevolve.utils.async_utils import TaskPool, run_in_executor
+from openevolve.prompt.sampler import PromptSampler
 from openevolve.utils.format_utils import format_metrics_safe
 
 logger = logging.getLogger(__name__)
@@ -36,10 +38,12 @@ def __init__(
         config: EvaluatorConfig,
         evaluation_file: str,
         llm_ensemble: Optional[LLMEnsemble] = None,
+        prompt_sampler: Optional[PromptSampler] = None,
     ):
         self.config = config
         self.evaluation_file = evaluation_file
         self.llm_ensemble = llm_ensemble
+        self.prompt_sampler = prompt_sampler
 
         # Create a task pool for parallel evaluation
         self.task_pool = TaskPool(max_concurrency=config.parallel_evaluations)
@@ -286,67 +290,66 @@ async def _llm_evaluate(self, program_code: str) -> Dict[str, float]:
 
         try:
             # Create prompt for LLM
-            prompt = f"""
-            Evaluate the following code on a scale of 0.0 to 1.0 for the following metrics:
-            1. Readability: How easy is the code to read and understand?
-            2. Maintainability: How easy would the code be to maintain and modify?
-            3. Efficiency: How efficient is the code in terms of time and space complexity?
-
-            For each metric, provide a score between 0.0 and 1.0, where 1.0 is best.
-
-            Code to evaluate:
-            ```python
-            {program_code}
-            ```
-
-            Return your evaluation as a JSON object with the following format:
-            {{
-                "readability": [score],
-                "maintainability": [score],
-                "efficiency": [score],
-                "reasoning": "[brief explanation of scores]"
-            }}
-            """
+            prompt = self.prompt_sampler.build_prompt(
+                current_program=program_code, template_key="evaluation"
+            )
 
             # Get LLM response
-            response = await self.llm_ensemble.generate(prompt)
+            responses = await self.llm_ensemble.generate_all_with_context(
+                prompt["system"], [{"role": "user", "content": prompt["user"]}]
+            )
 
             # Extract JSON from response
             try:
                 # Try to find JSON block
                 json_pattern = r"```json\n(.*?)\n```"
                 import re
 
-                json_match = re.search(json_pattern, response, re.DOTALL)
-
-                if json_match:
-                    json_str = json_match.group(1)
-                else:
-                    # Try to extract JSON directly
-                    json_str = response
-                    # Remove non-JSON parts
-                    start_idx = json_str.find("{")
-                    end_idx = json_str.rfind("}") + 1
-                    if start_idx >= 0 and end_idx > start_idx:
-                        json_str = json_str[start_idx:end_idx]
-
-                # Parse JSON
-                result = json.loads(json_str)
-
-                # Extract metrics
-                metrics = {}
-                for key in ["readability", "maintainability", "efficiency"]:
-                    if key in result:
-                        metrics[key] = float(result[key])
-
-                return metrics
+                avg_metrics = {}
+                for i, response in enumerate(responses):
+                    json_match = re.search(json_pattern, response, re.DOTALL)
+
+                    if json_match:
+                        json_str = json_match.group(1)
+                    else:
+                        # Try to extract JSON directly
+                        json_str = response
+                        # Remove non-JSON parts
+                        start_idx = json_str.find("{")
+                        end_idx = json_str.rfind("}") + 1
+                        if start_idx >= 0 and end_idx > start_idx:
+                            json_str = json_str[start_idx:end_idx]
+
+                    # Parse JSON
+                    result = json.loads(json_str)
+
+                    # Filter all non-numeric values
+                    metrics = {
+                        name: float(value)
+                        for name, value in result.items()
+                        if isinstance(value, (int, float))
+                    }
+
+                    # Weight of the model in the ensemble
+                    weight = self.llm_ensemble.weights[i] if self.llm_ensemble.weights else 1.0
+
+                    # Average the metrics
+                    for name, value in metrics.items():
+                        if name in avg_metrics:
+                            avg_metrics[name] += value * weight
+                        else:
+                            avg_metrics[name] = value * weight
+
+                return avg_metrics
 
             except Exception as e:
                 logger.warning(f"Error parsing LLM response: {str(e)}")
+                traceback.print_exc()
                 return {}
 
         except Exception as e:
             logger.error(f"Error in LLM evaluation: {str(e)}")
+            traceback.print_exc()
             return {}
 
     def _passes_threshold(self, metrics: Dict[str, float], threshold: float) -> bool:
````
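
The new `_llm_evaluate` loop combines the per-model scores as a weight-scaled sum, which behaves as a weighted average when the ensemble weights sum to 1 (as they do in the default config). A standalone sketch of that aggregation step, factored out here purely for illustration with made-up scores:

```python
from typing import Dict, List


def aggregate_metrics(
    per_model: List[Dict[str, float]], weights: List[float]
) -> Dict[str, float]:
    """Combine metric dicts from an LLM ensemble, scaling each by its model's weight."""
    combined: Dict[str, float] = {}
    for metrics, weight in zip(per_model, weights):
        for name, value in metrics.items():
            combined[name] = combined.get(name, 0.0) + value * weight
    return combined


# Two evaluator models weighted 0.8 / 0.2, as in the default config:
print(aggregate_metrics(
    [{"readability": 0.9, "efficiency": 0.6}, {"readability": 0.5, "efficiency": 0.8}],
    [0.8, 0.2],
))
# ≈ {'readability': 0.82, 'efficiency': 0.64}
```

Note that a metric reported by only one model is scaled by that model's weight alone, so it ends up discounted rather than averaged.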
