Skip to content

Commit d90d47f

Browse files
authored
Merge pull request #136 from codelion/fix-pop-enforce-bug
Improve population management and reproducibility in evolution
2 parents ce4570e + 0ff351e commit d90d47f

File tree

10 files changed

+222
-181
lines changed

10 files changed

+222
-181
lines changed

openevolve/config.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class LLMModelConfig:
3232
timeout: int = None
3333
retries: int = None
3434
retry_delay: int = None
35-
35+
3636
# Reproducibility
3737
random_seed: Optional[int] = None
3838

@@ -56,10 +56,12 @@ class LLMConfig(LLMModelConfig):
5656
retry_delay: int = 5
5757

5858
# n-model configuration for evolution LLM ensemble
59-
models: List[LLMModelConfig] = field(default_factory=lambda: [
60-
LLMModelConfig(name="gpt-4o-mini", weight=0.8),
61-
LLMModelConfig(name="gpt-4o", weight=0.2)
62-
])
59+
models: List[LLMModelConfig] = field(
60+
default_factory=lambda: [
61+
LLMModelConfig(name="gpt-4o-mini", weight=0.8),
62+
LLMModelConfig(name="gpt-4o", weight=0.2),
63+
]
64+
)
6365

6466
# n-model configuration for evaluator LLM ensemble
6567
evaluator_models: List[LLMModelConfig] = field(default_factory=lambda: [])
@@ -264,7 +266,7 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "Config":
264266
config.prompt = PromptConfig(**config_dict["prompt"])
265267
if "database" in config_dict:
266268
config.database = DatabaseConfig(**config_dict["database"])
267-
269+
268270
# Ensure database inherits the random seed if not explicitly set
269271
if config.database.random_seed is None and config.random_seed is not None:
270272
config.database.random_seed = config.random_seed
@@ -365,4 +367,4 @@ def load_config(config_path: Optional[Union[str, Path]] = None) -> Config:
365367
# Make the system message available to the individual models, in case it is not provided from the prompt sampler
366368
config.llm.update_model_params({"system_message": config.prompt.system_message})
367369

368-
return config
370+
return config

openevolve/controller.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -104,20 +104,20 @@ def __init__(
104104
# Set global random seeds
105105
random.seed(self.config.random_seed)
106106
np.random.seed(self.config.random_seed)
107-
107+
108108
# Create hash-based seeds for different components
109-
base_seed = str(self.config.random_seed).encode('utf-8')
110-
llm_seed = int(hashlib.md5(base_seed + b'llm').hexdigest()[:8], 16) % (2**31)
111-
109+
base_seed = str(self.config.random_seed).encode("utf-8")
110+
llm_seed = int(hashlib.md5(base_seed + b"llm").hexdigest()[:8], 16) % (2**31)
111+
112112
# Propagate seed to LLM configurations
113113
self.config.llm.random_seed = llm_seed
114114
for model_cfg in self.config.llm.models:
115-
if not hasattr(model_cfg, 'random_seed') or model_cfg.random_seed is None:
115+
if not hasattr(model_cfg, "random_seed") or model_cfg.random_seed is None:
116116
model_cfg.random_seed = llm_seed
117117
for model_cfg in self.config.llm.evaluator_models:
118-
if not hasattr(model_cfg, 'random_seed') or model_cfg.random_seed is None:
118+
if not hasattr(model_cfg, "random_seed") or model_cfg.random_seed is None:
119119
model_cfg.random_seed = llm_seed
120-
120+
121121
logger.info(f"Set random seed to {self.config.random_seed} for reproducibility")
122122
logger.debug(f"Generated LLM seed: {llm_seed}")
123123

@@ -161,7 +161,7 @@ def __init__(
161161
self.evaluation_file = evaluation_file
162162

163163
logger.info(f"Initialized OpenEvolve with {initial_program_path}")
164-
164+
165165
# Initialize improved parallel processing components
166166
self.parallel_controller = None
167167

@@ -212,7 +212,7 @@ async def run(
212212
Best program found
213213
"""
214214
max_iterations = iterations or self.config.max_iterations
215-
215+
216216
# Determine starting iteration
217217
start_iteration = 0
218218
if checkpoint_path and os.path.exists(checkpoint_path):
@@ -260,30 +260,31 @@ async def run(
260260
self.parallel_controller = ImprovedParallelController(
261261
self.config, self.evaluation_file, self.database
262262
)
263-
263+
264264
# Set up signal handlers for graceful shutdown
265265
def signal_handler(signum, frame):
266266
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
267267
self.parallel_controller.request_shutdown()
268-
268+
269269
# Set up a secondary handler for immediate exit if user presses Ctrl+C again
270270
def force_exit_handler(signum, frame):
271271
logger.info("Force exit requested - terminating immediately")
272272
import sys
273+
273274
sys.exit(0)
274-
275+
275276
signal.signal(signal.SIGINT, force_exit_handler)
276-
277+
277278
signal.signal(signal.SIGINT, signal_handler)
278279
signal.signal(signal.SIGTERM, signal_handler)
279-
280+
280281
self.parallel_controller.start()
281-
282+
282283
# Run evolution with improved parallel processing and checkpoint callback
283284
await self._run_evolution_with_checkpoints(
284285
start_iteration, max_iterations, target_score
285286
)
286-
287+
287288
finally:
288289
# Clean up parallel processing resources
289290
if self.parallel_controller:
@@ -420,31 +421,28 @@ def _load_checkpoint(self, checkpoint_path: str) -> None:
420421
"""Load state from a checkpoint directory"""
421422
if not os.path.exists(checkpoint_path):
422423
raise FileNotFoundError(f"Checkpoint directory {checkpoint_path} not found")
423-
424+
424425
logger.info(f"Loading checkpoint from {checkpoint_path}")
425426
self.database.load(checkpoint_path)
426-
logger.info(
427-
f"Checkpoint loaded successfully (iteration {self.database.last_iteration})"
428-
)
427+
logger.info(f"Checkpoint loaded successfully (iteration {self.database.last_iteration})")
429428

430429
async def _run_evolution_with_checkpoints(
431430
self, start_iteration: int, max_iterations: int, target_score: Optional[float]
432431
) -> None:
433432
"""Run evolution with checkpoint saving support"""
434433
logger.info(f"Using island-based evolution with {self.config.database.num_islands} islands")
435434
self.database.log_island_status()
436-
435+
437436
# Run the evolution process with checkpoint callback
438437
await self.parallel_controller.run_evolution(
439-
start_iteration, max_iterations, target_score,
440-
checkpoint_callback=self._save_checkpoint
438+
start_iteration, max_iterations, target_score, checkpoint_callback=self._save_checkpoint
441439
)
442-
440+
443441
# Check if shutdown was requested
444442
if self.parallel_controller.shutdown_flag.is_set():
445443
logger.info("Evolution stopped due to shutdown request")
446444
return
447-
445+
448446
# Save final checkpoint if needed
449447
final_iteration = start_iteration + max_iterations - 1
450448
if final_iteration > 0 and final_iteration % self.config.checkpoint_interval == 0:
@@ -499,4 +497,4 @@ def _save_best_program(self, program: Optional[Program] = None) -> None:
499497
indent=2,
500498
)
501499

502-
logger.info(f"Saved best program to {code_path} with program info to {info_path}")
500+
logger.info(f"Saved best program to {code_path} with program info to {info_path}")

openevolve/database.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import random
1010
import time
1111
from dataclasses import asdict, dataclass, field, fields
12+
1213
# FileLock removed - no longer needed with threaded parallel processing
1314
from typing import Any, Dict, List, Optional, Set, Tuple, Union
1415

@@ -198,7 +199,11 @@ def add(
198199
# Update archive
199200
self._update_archive(program)
200201

201-
# Update the absolute best program tracking
202+
# Enforce population size limit BEFORE updating best program tracking
203+
# This ensures newly added programs aren't immediately removed
204+
self._enforce_population_limit(exclude_program_id=program.id)
205+
206+
# Update the absolute best program tracking (after population enforcement)
202207
self._update_best_program(program)
203208

204209
# Save to disk if configured
@@ -207,9 +212,6 @@ def add(
207212

208213
logger.debug(f"Added program {program.id} to island {island_idx}")
209214

210-
# Enforce population size limit
211-
self._enforce_population_limit()
212-
213215
return program.id
214216

215217
def get(self, program_id: str) -> Optional[Program]:
@@ -254,9 +256,15 @@ def get_best_program(self, metric: Optional[str] = None) -> Optional[Program]:
254256
return None
255257

256258
# If no specific metric and we have a tracked best program, return it
257-
if metric is None and self.best_program_id and self.best_program_id in self.programs:
258-
logger.debug(f"Using tracked best program: {self.best_program_id}")
259-
return self.programs[self.best_program_id]
259+
if metric is None and self.best_program_id:
260+
if self.best_program_id in self.programs:
261+
logger.debug(f"Using tracked best program: {self.best_program_id}")
262+
return self.programs[self.best_program_id]
263+
else:
264+
logger.warning(
265+
f"Tracked best program {self.best_program_id} no longer exists, will recalculate"
266+
)
267+
self.best_program_id = None
260268

261269
if metric:
262270
# Sort by specific metric
@@ -713,7 +721,15 @@ def _update_best_program(self, program: Program) -> None:
713721
logger.debug(f"Set initial best program to {program.id}")
714722
return
715723

716-
# Compare with current best program
724+
# Compare with current best program (if it still exists)
725+
if self.best_program_id not in self.programs:
726+
logger.warning(
727+
f"Best program {self.best_program_id} no longer exists, clearing reference"
728+
)
729+
self.best_program_id = program.id
730+
logger.info(f"Set new best program to {program.id}")
731+
return
732+
717733
current_best = self.programs[self.best_program_id]
718734

719735
# Update if the new program is better
@@ -940,9 +956,12 @@ def _sample_inspirations(self, parent: Program, n: int = 5) -> List[Program]:
940956

941957
return inspirations[:n]
942958

943-
def _enforce_population_limit(self) -> None:
959+
def _enforce_population_limit(self, exclude_program_id: Optional[str] = None) -> None:
944960
"""
945961
Enforce the population size limit by removing worst programs if needed
962+
963+
Args:
964+
exclude_program_id: Program ID to never remove (e.g., newly added program)
946965
"""
947966
if len(self.programs) <= self.config.population_size:
948967
return
@@ -963,22 +982,24 @@ def _enforce_population_limit(self) -> None:
963982
key=lambda p: safe_numeric_average(p.metrics),
964983
)
965984

966-
# Remove worst programs, but never remove the best program
985+
# Remove worst programs, but never remove the best program or excluded program
967986
programs_to_remove = []
987+
protected_ids = {self.best_program_id, exclude_program_id} - {None}
988+
968989
for program in sorted_programs:
969990
if len(programs_to_remove) >= num_to_remove:
970991
break
971-
# Don't remove the best program
972-
if program.id != self.best_program_id:
992+
# Don't remove the best program or excluded program
993+
if program.id not in protected_ids:
973994
programs_to_remove.append(program)
974995

975-
# If we still need to remove more and only have the best program protected,
976-
# remove from the remaining programs anyway (but keep the absolute best)
996+
# If we still need to remove more and only have protected programs,
997+
# remove from the remaining programs anyway (but keep the protected ones)
977998
if len(programs_to_remove) < num_to_remove:
978999
remaining_programs = [
9791000
p
9801001
for p in sorted_programs
981-
if p not in programs_to_remove and p.id != self.best_program_id
1002+
if p not in programs_to_remove and p.id not in protected_ids
9821003
]
9831004
additional_removals = remaining_programs[: num_to_remove - len(programs_to_remove)]
9841005
programs_to_remove.extend(additional_removals)

openevolve/iteration.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,21 @@ class Result:
3131
artifacts: dict = None
3232

3333

34-
35-
36-
3734
async def run_iteration_with_shared_db(
38-
iteration: int,
39-
config: Config,
35+
iteration: int,
36+
config: Config,
4037
database: ProgramDatabase,
4138
evaluator: Evaluator,
4239
llm_ensemble: LLMEnsemble,
43-
prompt_sampler: PromptSampler
40+
prompt_sampler: PromptSampler,
4441
):
4542
"""
4643
Run a single iteration using shared memory database
47-
44+
4845
This is optimized for use with persistent worker processes.
4946
"""
5047
logger = logging.getLogger(__name__)
51-
48+
5249
try:
5350
# Sample parent and inspirations from database
5451
parent, inspirations = database.sample()
@@ -115,10 +112,10 @@ async def run_iteration_with_shared_db(
115112
# Evaluate the child program
116113
child_id = str(uuid.uuid4())
117114
result.child_metrics = await evaluator.evaluate_program(child_code, child_id)
118-
115+
119116
# Handle artifacts if they exist
120117
artifacts = evaluator.get_pending_artifacts(child_id)
121-
118+
122119
# Create a child program
123120
result.child_program = Program(
124121
id=child_id,
@@ -133,13 +130,13 @@ async def run_iteration_with_shared_db(
133130
"parent_metrics": parent.metrics,
134131
},
135132
)
136-
133+
137134
result.prompt = prompt
138135
result.llm_response = llm_response
139136
result.artifacts = artifacts
140137
result.iteration_time = time.time() - iteration_start
141138
result.iteration = iteration
142-
139+
143140
return result
144141

145142
except Exception as e:

openevolve/llm/ensemble.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,22 @@ def __init__(self, models_cfg: List[LLMModelConfig]):
2727
self.weights = [model.weight for model in models_cfg]
2828
total = sum(self.weights)
2929
self.weights = [w / total for w in self.weights]
30-
30+
3131
# Set up random state for deterministic model selection
3232
self.random_state = random.Random()
3333
# Initialize with seed from first model's config if available
34-
if models_cfg and hasattr(models_cfg[0], 'random_seed') and models_cfg[0].random_seed is not None:
34+
if (
35+
models_cfg
36+
and hasattr(models_cfg[0], "random_seed")
37+
and models_cfg[0].random_seed is not None
38+
):
3539
self.random_state.seed(models_cfg[0].random_seed)
36-
logger.debug(f"LLMEnsemble: Set random seed to {models_cfg[0].random_seed} for deterministic model selection")
40+
logger.debug(
41+
f"LLMEnsemble: Set random seed to {models_cfg[0].random_seed} for deterministic model selection"
42+
)
3743

3844
# Only log if we have multiple models or this is the first ensemble
39-
if len(models_cfg) > 1 or not hasattr(logger, '_ensemble_logged'):
45+
if len(models_cfg) > 1 or not hasattr(logger, "_ensemble_logged"):
4046
logger.info(
4147
f"Initialized LLM ensemble with models: "
4248
+ ", ".join(

openevolve/llm/openai.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
self.retry_delay = model_cfg.retry_delay
3333
self.api_base = model_cfg.api_base
3434
self.api_key = model_cfg.api_key
35-
self.random_seed = getattr(model_cfg, 'random_seed', None)
35+
self.random_seed = getattr(model_cfg, "random_seed", None)
3636

3737
# Set up API client
3838
self.client = openai.OpenAI(
@@ -41,9 +41,9 @@ def __init__(
4141
)
4242

4343
# Only log unique models to reduce duplication
44-
if not hasattr(logger, '_initialized_models'):
44+
if not hasattr(logger, "_initialized_models"):
4545
logger._initialized_models = set()
46-
46+
4747
if self.model not in logger._initialized_models:
4848
logger.info(f"Initialized OpenAI LLM with model: {self.model}")
4949
logger._initialized_models.add(self.model)
@@ -80,7 +80,7 @@ async def generate_with_context(
8080
"top_p": kwargs.get("top_p", self.top_p),
8181
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
8282
}
83-
83+
8484
# Add seed parameter for reproducibility if configured
8585
# Skip seed parameter for Google AI Studio endpoint as it doesn't support it
8686
seed = kwargs.get("seed", self.random_seed)

0 commit comments

Comments
 (0)