Commit 4c6535e

fix feature stats persistence
1 parent 102d419 commit 4c6535e

File tree

3 files changed: +671 -0 lines changed

openevolve/database.py

Lines changed: 62 additions & 0 deletions
@@ -486,6 +486,7 @@ def save(self, path: Optional[str] = None, iteration: int = 0) -> None:
             "current_island": self.current_island,
             "island_generations": self.island_generations,
             "last_migration_generation": self.last_migration_generation,
+            "feature_stats": self._serialize_feature_stats(),
         }

         with open(os.path.join(save_path, "metadata.json"), "w") as f:
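
For illustration only, the metadata dict written to metadata.json would then carry a feature_stats block alongside the existing fields; the sketch below uses made-up feature names and values, not data from a real checkpoint:

# Hypothetical shape of the persisted metadata (Python literal for readability);
# the numbers and dimension names here are invented examples.
metadata = {
    "last_migration_generation": 0,
    # ... other existing fields ...
    "feature_stats": {
        "score": {"min": 0.1, "max": 0.9, "values": [0.1, 0.3, 0.9]},
        "custom_metric1": {"min": 10.0, "max": 90.0, "values": [10.0, 50.0, 90.0]},
    },
}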
@@ -522,8 +523,13 @@ def load(self, path: str) -> None:
         self.current_island = metadata.get("current_island", 0)
         self.island_generations = metadata.get("island_generations", [0] * len(saved_islands))
         self.last_migration_generation = metadata.get("last_migration_generation", 0)
+
+        # Load feature_stats for MAP-Elites grid stability
+        self.feature_stats = self._deserialize_feature_stats(metadata.get("feature_stats", {}))

         logger.info(f"Loaded database metadata with last_iteration={self.last_iteration}")
+        if self.feature_stats:
+            logger.info(f"Loaded feature_stats for {len(self.feature_stats)} dimensions")

         # Load programs
         programs_dir = os.path.join(path, "programs")
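
A quick sketch of the fallback path this hunk relies on for older checkpoints: metadata.get("feature_stats", {}) yields an empty dict when the key is absent, and _deserialize_feature_stats passes that straight through, so loading pre-existing checkpoints stays safe (assumption: nothing else requires a non-empty feature_stats).

# Sketch of the backward-compatibility path; `metadata` stands in for the
# dict loaded from an old metadata.json that predates this change.
metadata = {"last_iteration": 10}              # no "feature_stats" key
raw = metadata.get("feature_stats", {})        # -> {}
# self.feature_stats = self._deserialize_feature_stats(raw)   # -> {}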
@@ -1815,6 +1821,62 @@ def _scale_feature_value_minmax(self, feature_name: str, value: float) -> float:
         scaled = (value - min_val) / (max_val - min_val)
         return min(1.0, max(0.0, scaled))

+    def _serialize_feature_stats(self) -> Dict[str, Any]:
+        """
+        Serialize feature_stats for JSON storage
+
+        Returns:
+            Dictionary that can be JSON-serialized
+        """
+        serialized = {}
+        for feature_name, stats in self.feature_stats.items():
+            # Convert to JSON-serializable format
+            serialized_stats = {}
+            for key, value in stats.items():
+                if key == "values":
+                    # Limit size to prevent excessive memory usage
+                    # Keep only the most recent 100 values for percentile calculations
+                    if isinstance(value, list) and len(value) > 100:
+                        serialized_stats[key] = value[-100:]
+                    else:
+                        serialized_stats[key] = value
+                else:
+                    # Convert numpy types to Python native types
+                    if hasattr(value, 'item'):  # numpy scalar
+                        serialized_stats[key] = value.item()
+                    else:
+                        serialized_stats[key] = value
+            serialized[feature_name] = serialized_stats
+        return serialized
+
+    def _deserialize_feature_stats(self, stats_dict: Dict[str, Any]) -> Dict[str, Dict[str, Union[float, List[float]]]]:
+        """
+        Deserialize feature_stats from loaded JSON
+
+        Args:
+            stats_dict: Dictionary loaded from JSON
+
+        Returns:
+            Properly formatted feature_stats dictionary
+        """
+        if not stats_dict:
+            return {}
+
+        deserialized = {}
+        for feature_name, stats in stats_dict.items():
+            if isinstance(stats, dict):
+                # Ensure proper structure and types
+                deserialized_stats = {
+                    "min": float(stats.get("min", 0.0)),
+                    "max": float(stats.get("max", 1.0)),
+                    "values": list(stats.get("values", [])),
+                }
+                deserialized[feature_name] = deserialized_stats
+            else:
+                logger.warning(f"Skipping malformed feature_stats entry for '{feature_name}': {stats}")
+
+        return deserialized
+
     def log_island_status(self) -> None:
         """Log current status of all islands"""
         stats = self.get_island_stats()
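
As a sanity check, a minimal round-trip sketch of the two helpers above (assuming `db` is a ProgramDatabase whose feature_stats have been populated; json.dumps/json.loads here mirror what save/load do through metadata.json):

import json

# Round-trip: serialize -> JSON string -> deserialize should keep every
# dimension and hand back plain-float min/max plus a list of values.
payload = json.dumps(db._serialize_feature_stats())
restored = db._deserialize_feature_stats(json.loads(payload))

assert set(restored) == set(db.feature_stats)
for name, stats in restored.items():
    assert isinstance(stats["min"], float) and isinstance(stats["max"], float)
    assert isinstance(stats["values"], list)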
Lines changed: 303 additions & 0 deletions
@@ -0,0 +1,303 @@
"""
Unit tests for feature_stats persistence in ProgramDatabase checkpoints
"""

import json
import os
import tempfile
import shutil
import unittest
from unittest.mock import patch

from openevolve.database import ProgramDatabase, Program
from openevolve.config import DatabaseConfig


class TestFeatureStatsPersistence(unittest.TestCase):
    """Test feature_stats are correctly saved and loaded in checkpoints"""

    def setUp(self):
        """Set up test environment"""
        self.test_dir = tempfile.mkdtemp()
        self.config = DatabaseConfig(
            db_path=self.test_dir,
            feature_dimensions=["score", "custom_metric1", "custom_metric2"],
            feature_bins=10
        )

    def tearDown(self):
        """Clean up test environment"""
        shutil.rmtree(self.test_dir)

    def test_feature_stats_saved_and_loaded(self):
        """Test that feature_stats are correctly saved and loaded from checkpoints"""
        # Create database and add programs to build feature_stats
        db1 = ProgramDatabase(self.config)

        programs = []
        for i in range(5):
            program = Program(
                id=f"test_prog_{i}",
                code=f"# Test program {i}",
                metrics={
                    "combined_score": 0.1 + i * 0.2,
                    "custom_metric1": 10 + i * 20,
                    "custom_metric2": 100 + i * 50
                }
            )
            programs.append(program)
            db1.add(program)

        # Verify feature_stats were built
        self.assertIn("score", db1.feature_stats)
        self.assertIn("custom_metric1", db1.feature_stats)
        self.assertIn("custom_metric2", db1.feature_stats)

        # Store original feature_stats for comparison
        original_stats = {
            dim: {
                "min": stats["min"],
                "max": stats["max"],
                "values": stats["values"].copy()
            }
            for dim, stats in db1.feature_stats.items()
        }

        # Save checkpoint
        db1.save(self.test_dir, iteration=42)

        # Load into new database
        db2 = ProgramDatabase(self.config)
        db2.load(self.test_dir)

        # Verify feature_stats were loaded correctly
        self.assertEqual(len(db2.feature_stats), len(original_stats))

        for dim, original in original_stats.items():
            self.assertIn(dim, db2.feature_stats)
            loaded = db2.feature_stats[dim]

            self.assertAlmostEqual(loaded["min"], original["min"], places=5)
            self.assertAlmostEqual(loaded["max"], original["max"], places=5)
            self.assertEqual(loaded["values"], original["values"])

    def test_empty_feature_stats_handling(self):
        """Test handling of empty feature_stats"""
        db1 = ProgramDatabase(self.config)

        # Save without any programs (empty feature_stats)
        db1.save(self.test_dir, iteration=1)

        # Load and verify
        db2 = ProgramDatabase(self.config)
        db2.load(self.test_dir)

        self.assertEqual(db2.feature_stats, {})

    def test_backward_compatibility_missing_feature_stats(self):
        """Test loading checkpoints that don't have feature_stats (backward compatibility)"""
        # Create a checkpoint manually without feature_stats
        os.makedirs(self.test_dir, exist_ok=True)

        # Create metadata without feature_stats (simulating old checkpoint)
        metadata = {
            "feature_map": {},
            "islands": [[]],
            "archive": [],
            "best_program_id": None,
            "island_best_programs": [None],
            "last_iteration": 10,
            "current_island": 0,
            "island_generations": [0],
            "last_migration_generation": 0,
            # Note: no "feature_stats" key
        }

        with open(os.path.join(self.test_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)

        # Load should work without errors
        db = ProgramDatabase(self.config)
        db.load(self.test_dir)

        # feature_stats should be empty but not None
        self.assertEqual(db.feature_stats, {})

    def test_feature_stats_serialization_edge_cases(self):
        """Test feature_stats serialization handles edge cases correctly"""
        db = ProgramDatabase(self.config)

        # Test with various edge cases
        db.feature_stats = {
            "normal_case": {
                "min": 1.0,
                "max": 10.0,
                "values": [1.0, 5.0, 10.0]
            },
            "single_value": {
                "min": 5.0,
                "max": 5.0,
                "values": [5.0]
            },
            "large_values_list": {
                "min": 0.0,
                "max": 200.0,
                "values": list(range(200))  # Should be truncated to 100
            },
            "empty_values": {
                "min": 0.0,
                "max": 1.0,
                "values": []
            }
        }

        # Test serialization
        serialized = db._serialize_feature_stats()

        # Check that large values list was truncated
        self.assertLessEqual(len(serialized["large_values_list"]["values"]), 100)

        # Test deserialization
        deserialized = db._deserialize_feature_stats(serialized)

        # Verify structure is maintained
        self.assertIn("normal_case", deserialized)
        self.assertIn("single_value", deserialized)
        self.assertIn("large_values_list", deserialized)
        self.assertIn("empty_values", deserialized)

        # Verify types are correct
        for dim, stats in deserialized.items():
            self.assertIsInstance(stats["min"], float)
            self.assertIsInstance(stats["max"], float)
            self.assertIsInstance(stats["values"], list)

    def test_feature_stats_preservation_during_load(self):
        """Test that feature_stats ranges are preserved when loading from checkpoint"""
        # Create database with programs
        db1 = ProgramDatabase(self.config)

        test_programs = []

        for i in range(3):
            program = Program(
                id=f"stats_test_{i}",
                code=f"# Stats test {i}",
                metrics={
                    "combined_score": 0.2 + i * 0.3,
                    "custom_metric1": 20 + i * 30,
                    "custom_metric2": 200 + i * 100
                }
            )
            test_programs.append(program)
            db1.add(program)

        # Record original feature ranges
        original_ranges = {}
        for dim, stats in db1.feature_stats.items():
            original_ranges[dim] = {
                "min": stats["min"],
                "max": stats["max"]
            }

        # Save checkpoint
        db1.save(self.test_dir, iteration=50)

        # Load into new database
        db2 = ProgramDatabase(self.config)
        db2.load(self.test_dir)

        # Verify feature ranges are preserved
        for dim, original_range in original_ranges.items():
            self.assertIn(dim, db2.feature_stats)
            loaded_stats = db2.feature_stats[dim]

            self.assertAlmostEqual(
                loaded_stats["min"], original_range["min"], places=5,
                msg=f"Min value changed for {dim}: {original_range['min']} -> {loaded_stats['min']}"
            )
            self.assertAlmostEqual(
                loaded_stats["max"], original_range["max"], places=5,
                msg=f"Max value changed for {dim}: {original_range['max']} -> {loaded_stats['max']}"
            )

        # Test that adding a new program within existing ranges doesn't break anything
        new_program = Program(
            id="range_test",
            code="# Program to test range stability",
            metrics={
                "combined_score": 0.35,  # Within existing range
                "custom_metric1": 35,    # Within existing range
                "custom_metric2": 250    # Within existing range
            }
        )

        # Adding this program should not cause issues
        db2.add(new_program)
        new_coords = db2._calculate_feature_coords(new_program)

        # Should get valid coordinates
        self.assertEqual(len(new_coords), len(self.config.feature_dimensions))
        for coord in new_coords:
            self.assertIsInstance(coord, int)
            self.assertGreaterEqual(coord, 0)

    def test_feature_stats_with_numpy_types(self):
        """Test that numpy types are correctly handled in serialization"""
        import numpy as np

        db = ProgramDatabase(self.config)

        # Simulate feature_stats with numpy types
        db.feature_stats = {
            "numpy_test": {
                "min": np.float64(1.5),
                "max": np.float64(9.5),
                "values": [np.float64(x) for x in [1.5, 5.0, 9.5]]
            }
        }

        # Test serialization doesn't fail
        serialized = db._serialize_feature_stats()

        # Verify numpy types were converted to Python types
        self.assertIsInstance(serialized["numpy_test"]["min"], float)
        self.assertIsInstance(serialized["numpy_test"]["max"], float)

        # Test deserialization
        deserialized = db._deserialize_feature_stats(serialized)
        self.assertIsInstance(deserialized["numpy_test"]["min"], float)
        self.assertIsInstance(deserialized["numpy_test"]["max"], float)

    def test_malformed_feature_stats_handling(self):
        """Test handling of malformed feature_stats during deserialization"""
        db = ProgramDatabase(self.config)

        # Test with malformed data
        malformed_data = {
            "valid_entry": {
                "min": 1.0,
                "max": 10.0,
                "values": [1.0, 5.0, 10.0]
            },
            "invalid_entry": "this is not a dict",
            "missing_keys": {
                "min": 1.0
                # missing "max" and "values"
            }
        }

        with patch('openevolve.database.logger') as mock_logger:
            deserialized = db._deserialize_feature_stats(malformed_data)

        # Should have valid entry and skip invalid ones
        self.assertIn("valid_entry", deserialized)
        self.assertNotIn("invalid_entry", deserialized)
        self.assertIn("missing_keys", deserialized)  # Should be created with defaults

        # Should have logged warning for invalid entry
        mock_logger.warning.assert_called()


if __name__ == "__main__":
    unittest.main()
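
The suite should run under the standard library's unittest runner; a minimal programmatic sketch follows (the tests/ start directory and discovery pattern are assumptions, since the new file's path is not shown in this view):

# Discover and run the tests; adjust the start directory to the repo's layout.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)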

0 commit comments
