Skip to content

Commit 4f8ff58

Browse files
committed
as
1 parent 1fb2b70 commit 4f8ff58

File tree

4 files changed

+509
-12
lines changed

4 files changed

+509
-12
lines changed

openevolve/database.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,95 @@ def sample(self, num_inspirations: Optional[int] = None) -> Tuple[Program, List[
363363
logger.debug(f"Sampled parent {parent.id} and {len(inspirations)} inspirations")
364364
return parent, inspirations
365365

366+
def sample_from_island(
367+
self, island_id: int, num_inspirations: Optional[int] = None
368+
) -> Tuple[Program, List[Program]]:
369+
"""
370+
Sample a program and inspirations from a specific island without modifying current_island
371+
372+
This method is thread-safe and doesn't modify shared state, avoiding race conditions
373+
when multiple workers sample from different islands concurrently.
374+
375+
Args:
376+
island_id: The island to sample from
377+
num_inspirations: Number of inspiration programs to sample (defaults to 5)
378+
379+
Returns:
380+
Tuple of (parent_program, inspiration_programs)
381+
"""
382+
# Ensure valid island ID
383+
island_id = island_id % len(self.islands)
384+
385+
# Get programs from the specific island
386+
island_programs = list(self.islands[island_id])
387+
388+
if not island_programs:
389+
# Island is empty, fall back to sampling from all programs
390+
logger.debug(f"Island {island_id} is empty, sampling from all programs")
391+
return self.sample(num_inspirations)
392+
393+
# Select parent from island programs
394+
if len(island_programs) == 1:
395+
parent_id = island_programs[0]
396+
else:
397+
# Use weighted sampling based on program scores
398+
island_program_objects = [
399+
self.programs[pid] for pid in island_programs
400+
if pid in self.programs
401+
]
402+
403+
if not island_program_objects:
404+
# Fallback if programs not found
405+
parent_id = random.choice(island_programs)
406+
else:
407+
# Calculate weights based on fitness scores
408+
weights = []
409+
for prog in island_program_objects:
410+
fitness = get_fitness_score(prog.metrics, self.config.feature_dimensions)
411+
# Add small epsilon to avoid zero weights
412+
weights.append(max(fitness, 0.001))
413+
414+
# Normalize weights
415+
total_weight = sum(weights)
416+
if total_weight > 0:
417+
weights = [w / total_weight for w in weights]
418+
else:
419+
weights = [1.0 / len(island_program_objects)] * len(island_program_objects)
420+
421+
# Sample parent based on weights
422+
parent = random.choices(island_program_objects, weights=weights, k=1)[0]
423+
parent_id = parent.id
424+
425+
parent = self.programs.get(parent_id)
426+
if not parent:
427+
# Should not happen, but handle gracefully
428+
logger.error(f"Parent program {parent_id} not found in database")
429+
return self.sample(num_inspirations)
430+
431+
# Select inspirations from the same island
432+
if num_inspirations is None:
433+
num_inspirations = 5 # Default for backward compatibility
434+
435+
# Get other programs from the island for inspirations
436+
other_programs = [pid for pid in island_programs if pid != parent_id]
437+
438+
if len(other_programs) < num_inspirations:
439+
# Not enough programs in island, use what we have
440+
inspiration_ids = other_programs
441+
else:
442+
# Sample inspirations
443+
inspiration_ids = random.sample(other_programs, num_inspirations)
444+
445+
inspirations = [
446+
self.programs[pid] for pid in inspiration_ids
447+
if pid in self.programs
448+
]
449+
450+
logger.debug(
451+
f"Sampled parent {parent.id} and {len(inspirations)} inspirations from island {island_id}"
452+
)
453+
return parent, inspirations
454+
366455
def get_best_program(self, metric: Optional[str] = None) -> Optional[Program]:
367456
"""
368457
Get the best program based on a metric

openevolve/process_parallel.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -671,18 +671,12 @@ def _submit_iteration(
671671
# Use specified island or current island
672672
target_island = island_id if island_id is not None else self.database.current_island
673673

674-
# Temporarily set database to target island for sampling
675-
original_island = self.database.current_island
676-
self.database.current_island = target_island
677-
678-
try:
679-
# Sample parent and inspirations from the target island
680-
parent, inspirations = self.database.sample(
681-
num_inspirations=self.config.prompt.num_top_programs
682-
)
683-
finally:
684-
# Always restore original island state
685-
self.database.current_island = original_island
674+
# Use thread-safe sampling that doesn't modify shared state
675+
# This fixes the race condition from GitHub issue #246
676+
parent, inspirations = self.database.sample_from_island(
677+
island_id=target_island,
678+
num_inspirations=self.config.prompt.num_top_programs
679+
)
686680

687681
# Create database snapshot
688682
db_snapshot = self._create_database_snapshot()
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
"""
2+
Test to reproduce and verify fix for GitHub issue #246
3+
Process pool termination due to concurrent island access race condition
4+
"""
5+
import unittest
6+
import tempfile
7+
import os
8+
import asyncio
9+
from concurrent.futures import ThreadPoolExecutor
10+
from unittest.mock import MagicMock, patch
11+
12+
from openevolve.database import ProgramDatabase
13+
from openevolve.config import Config
14+
from openevolve.database import Program
15+
16+
17+
class TestConcurrentIslandAccess(unittest.TestCase):
18+
"""Test concurrent access to island state in database"""
19+
20+
def setUp(self):
21+
"""Set up test database with multiple islands"""
22+
self.config = Config()
23+
self.config.database.num_islands = 5
24+
self.config.database.population_size = 100
25+
26+
# Create temporary directory for database
27+
self.temp_dir = tempfile.mkdtemp()
28+
29+
# Initialize database (only takes config parameter)
30+
self.database = ProgramDatabase(self.config.database)
31+
32+
# Add some test programs to different islands
33+
for i in range(20):
34+
program = Program(
35+
id=f"prog_{i}",
36+
code=f"def test_{i}(): return {i}",
37+
metrics={"score": i * 0.1}
38+
)
39+
# Use target_island to ensure programs go to correct islands
40+
target_island = i % 5
41+
self.database.add(program, target_island=target_island)
42+
# Verify the program has the correct island metadata
43+
program.metadata["island"] = target_island
44+
45+
def tearDown(self):
46+
"""Clean up temp directory"""
47+
import shutil
48+
shutil.rmtree(self.temp_dir, ignore_errors=True)
49+
50+
def test_concurrent_island_state_modification_causes_race_condition(self):
51+
"""
52+
Test that concurrent modifications to current_island cause issues
53+
This simulates what happens in _submit_iteration when multiple workers
54+
try to sample from different islands simultaneously
55+
"""
56+
results = []
57+
errors = []
58+
59+
def sample_from_island(island_id):
60+
"""Simulate what _submit_iteration does"""
61+
try:
62+
# This is the problematic pattern from process_parallel.py
63+
original_island = self.database.current_island
64+
self.database.current_island = island_id
65+
66+
# Simulate some work (database sampling)
67+
import time
68+
time.sleep(0.001) # Small delay to increase chance of race
69+
70+
# Try to sample
71+
try:
72+
parent, inspirations = self.database.sample(num_inspirations=2)
73+
74+
# Check if we got programs from the correct island
75+
actual_island = parent.metadata.get("island", -1)
76+
results.append({
77+
"requested_island": island_id,
78+
"actual_island": actual_island,
79+
"restored_island": original_island,
80+
"current_island_after": self.database.current_island
81+
})
82+
finally:
83+
# Restore original island (but this might be wrong due to race!)
84+
self.database.current_island = original_island
85+
86+
except Exception as e:
87+
errors.append(str(e))
88+
89+
# Run concurrent sampling from different islands
90+
with ThreadPoolExecutor(max_workers=5) as executor:
91+
futures = []
92+
# Submit 20 tasks across 5 islands
93+
for i in range(20):
94+
future = executor.submit(sample_from_island, i % 5)
95+
futures.append(future)
96+
97+
# Wait for all to complete
98+
for future in futures:
99+
future.result()
100+
101+
# Check for race condition indicators
102+
race_conditions_found = False
103+
104+
for result in results:
105+
# Check if the restored island doesn't match what we expect
106+
# This would indicate another thread modified the state
107+
if result["actual_island"] != result["requested_island"]:
108+
print(f"Race condition detected: Requested island {result['requested_island']} "
109+
f"but got program from island {result['actual_island']}")
110+
race_conditions_found = True
111+
112+
# Check if any errors occurred
113+
if errors:
114+
print(f"Errors during concurrent access: {errors}")
115+
race_conditions_found = True
116+
117+
# This test EXPECTS to find race conditions with the current implementation
118+
# After the fix, this should be changed to assertFalse
119+
if race_conditions_found:
120+
print("✅ Successfully reproduced the race condition from issue #246")
121+
else:
122+
print("⚠️ Race condition not reproduced - may need more iterations or different timing")
123+
124+
def test_sequential_island_access_works_correctly(self):
125+
"""Test that sequential access works without issues"""
126+
results = []
127+
128+
for island_id in range(5):
129+
original_island = self.database.current_island
130+
self.database.current_island = island_id
131+
132+
try:
133+
parent, inspirations = self.database.sample(num_inspirations=2)
134+
actual_island = parent.metadata.get("island", -1)
135+
results.append({
136+
"requested": island_id,
137+
"actual": actual_island
138+
})
139+
finally:
140+
self.database.current_island = original_island
141+
142+
# All sequential accesses should work correctly
143+
for result in results:
144+
self.assertEqual(
145+
result["requested"],
146+
result["actual"],
147+
f"Sequential access failed: requested {result['requested']}, got {result['actual']}"
148+
)
149+
150+
print("✅ Sequential island access works correctly")
151+
152+
def test_proposed_fix_with_island_specific_sampling(self):
153+
"""
154+
Test the proposed fix: using a method that doesn't modify shared state
155+
This simulates what the fix would look like
156+
"""
157+
# Mock the proposed sample_from_island method
158+
def sample_from_island_safe(island_id, num_inspirations=2):
159+
"""
160+
Safe sampling that doesn't modify current_island
161+
This is what we'll implement in the database
162+
"""
163+
# Get programs from specific island without changing state
164+
island_programs = list(self.database.islands[island_id])
165+
if not island_programs:
166+
# Return random program if island is empty
167+
all_programs = list(self.database.programs.values())
168+
if all_programs:
169+
import random
170+
parent = random.choice(all_programs)
171+
inspirations = random.sample(all_programs, min(num_inspirations, len(all_programs)))
172+
return parent, inspirations
173+
return None, []
174+
175+
# Sample from island programs
176+
import random
177+
parent_id = random.choice(island_programs)
178+
parent = self.database.programs.get(parent_id)
179+
180+
inspiration_ids = random.sample(
181+
island_programs,
182+
min(num_inspirations, len(island_programs))
183+
)
184+
inspirations = [
185+
self.database.programs.get(pid)
186+
for pid in inspiration_ids
187+
if pid in self.database.programs
188+
]
189+
190+
return parent, inspirations
191+
192+
# Patch the database with our safe method
193+
self.database.sample_from_island = sample_from_island_safe
194+
195+
results = []
196+
errors = []
197+
198+
def safe_sample(island_id):
199+
"""Use the safe sampling method"""
200+
try:
201+
# No state modification needed!
202+
parent, inspirations = self.database.sample_from_island(
203+
island_id,
204+
num_inspirations=2
205+
)
206+
207+
if parent:
208+
actual_island = parent.metadata.get("island", -1)
209+
results.append({
210+
"requested_island": island_id,
211+
"actual_island": actual_island,
212+
"correct": island_id == actual_island
213+
})
214+
except Exception as e:
215+
errors.append(str(e))
216+
217+
# Run concurrent sampling with the safe method
218+
with ThreadPoolExecutor(max_workers=5) as executor:
219+
futures = []
220+
for i in range(20):
221+
future = executor.submit(safe_sample, i % 5)
222+
futures.append(future)
223+
224+
for future in futures:
225+
future.result()
226+
227+
# Check results - should have no race conditions
228+
all_correct = all(r["correct"] for r in results)
229+
230+
if all_correct and not errors:
231+
print("✅ Proposed fix eliminates the race condition!")
232+
else:
233+
incorrect = [r for r in results if not r["correct"]]
234+
print(f"❌ Issues found with proposed fix: {incorrect}, errors: {errors}")
235+
236+
self.assertTrue(all_correct, "Proposed fix should eliminate race conditions")
237+
self.assertEqual(len(errors), 0, "No errors should occur with safe sampling")
238+
239+
240+
if __name__ == "__main__":
241+
# Run the tests
242+
print("Testing concurrent island access (GitHub issue #246)...\n")
243+
244+
# Create test suite
245+
suite = unittest.TestLoader().loadTestsFromTestCase(TestConcurrentIslandAccess)
246+
247+
# Run with verbose output
248+
runner = unittest.TextTestRunner(verbosity=2)
249+
result = runner.run(suite)
250+
251+
print("\n" + "="*60)
252+
if result.wasSuccessful():
253+
print("All tests passed! The issue has been identified and the fix verified.")
254+
else:
255+
print("Some tests failed. Check the output above for details.")

0 commit comments

Comments
 (0)