1+ """
2+ Unit tests for feature_stats persistence in ProgramDatabase checkpoints
3+ """
4+
5+ import json
6+ import os
7+ import tempfile
8+ import shutil
9+ import unittest
10+ from unittest .mock import patch
11+
12+ from openevolve .database import ProgramDatabase , Program
13+ from openevolve .config import DatabaseConfig
14+
15+
16+ class TestFeatureStatsPersistence (unittest .TestCase ):
17+ """Test feature_stats are correctly saved and loaded in checkpoints"""
18+
19+ def setUp (self ):
20+ """Set up test environment"""
21+ self .test_dir = tempfile .mkdtemp ()
22+ self .config = DatabaseConfig (
23+ db_path = self .test_dir ,
24+ feature_dimensions = ["score" , "custom_metric1" , "custom_metric2" ],
25+ feature_bins = 10
26+ )
27+
28+ def tearDown (self ):
29+ """Clean up test environment"""
30+ shutil .rmtree (self .test_dir )
31+
32+ def test_feature_stats_saved_and_loaded (self ):
33+ """Test that feature_stats are correctly saved and loaded from checkpoints"""
34+ # Create database and add programs to build feature_stats
35+ db1 = ProgramDatabase (self .config )
36+
37+ programs = []
38+ for i in range (5 ):
39+ program = Program (
40+ id = f"test_prog_{ i } " ,
41+ code = f"# Test program { i } " ,
42+ metrics = {
43+ "combined_score" : 0.1 + i * 0.2 ,
44+ "custom_metric1" : 10 + i * 20 ,
45+ "custom_metric2" : 100 + i * 50
46+ }
47+ )
48+ programs .append (program )
49+ db1 .add (program )
50+
51+ # Verify feature_stats were built
52+ self .assertIn ("score" , db1 .feature_stats )
53+ self .assertIn ("custom_metric1" , db1 .feature_stats )
54+ self .assertIn ("custom_metric2" , db1 .feature_stats )
55+
56+ # Store original feature_stats for comparison
57+ original_stats = {
58+ dim : {
59+ "min" : stats ["min" ],
60+ "max" : stats ["max" ],
61+ "values" : stats ["values" ].copy ()
62+ }
63+ for dim , stats in db1 .feature_stats .items ()
64+ }
65+
66+ # Save checkpoint
67+ db1 .save (self .test_dir , iteration = 42 )
68+
69+ # Load into new database
70+ db2 = ProgramDatabase (self .config )
71+ db2 .load (self .test_dir )
72+
73+ # Verify feature_stats were loaded correctly
74+ self .assertEqual (len (db2 .feature_stats ), len (original_stats ))
75+
76+ for dim , original in original_stats .items ():
77+ self .assertIn (dim , db2 .feature_stats )
78+ loaded = db2 .feature_stats [dim ]
79+
80+ self .assertAlmostEqual (loaded ["min" ], original ["min" ], places = 5 )
81+ self .assertAlmostEqual (loaded ["max" ], original ["max" ], places = 5 )
82+ self .assertEqual (loaded ["values" ], original ["values" ])
83+
84+ def test_empty_feature_stats_handling (self ):
85+ """Test handling of empty feature_stats"""
86+ db1 = ProgramDatabase (self .config )
87+
88+ # Save without any programs (empty feature_stats)
89+ db1 .save (self .test_dir , iteration = 1 )
90+
91+ # Load and verify
92+ db2 = ProgramDatabase (self .config )
93+ db2 .load (self .test_dir )
94+
95+ self .assertEqual (db2 .feature_stats , {})
96+
97+ def test_backward_compatibility_missing_feature_stats (self ):
98+ """Test loading checkpoints that don't have feature_stats (backward compatibility)"""
99+ # Create a checkpoint manually without feature_stats
100+ os .makedirs (self .test_dir , exist_ok = True )
101+
102+ # Create metadata without feature_stats (simulating old checkpoint)
103+ metadata = {
104+ "feature_map" : {},
105+ "islands" : [[]],
106+ "archive" : [],
107+ "best_program_id" : None ,
108+ "island_best_programs" : [None ],
109+ "last_iteration" : 10 ,
110+ "current_island" : 0 ,
111+ "island_generations" : [0 ],
112+ "last_migration_generation" : 0 ,
113+ # Note: no "feature_stats" key
114+ }
115+
116+ with open (os .path .join (self .test_dir , "metadata.json" ), "w" ) as f :
117+ json .dump (metadata , f )
118+
119+ # Load should work without errors
120+ db = ProgramDatabase (self .config )
121+ db .load (self .test_dir )
122+
123+ # feature_stats should be empty but not None
124+ self .assertEqual (db .feature_stats , {})
125+
126+ def test_feature_stats_serialization_edge_cases (self ):
127+ """Test feature_stats serialization handles edge cases correctly"""
128+ db = ProgramDatabase (self .config )
129+
130+ # Test with various edge cases
131+ db .feature_stats = {
132+ "normal_case" : {
133+ "min" : 1.0 ,
134+ "max" : 10.0 ,
135+ "values" : [1.0 , 5.0 , 10.0 ]
136+ },
137+ "single_value" : {
138+ "min" : 5.0 ,
139+ "max" : 5.0 ,
140+ "values" : [5.0 ]
141+ },
142+ "large_values_list" : {
143+ "min" : 0.0 ,
144+ "max" : 200.0 ,
145+ "values" : list (range (200 )) # Should be truncated to 100
146+ },
147+ "empty_values" : {
148+ "min" : 0.0 ,
149+ "max" : 1.0 ,
150+ "values" : []
151+ }
152+ }
153+
154+ # Test serialization
155+ serialized = db ._serialize_feature_stats ()
156+
157+ # Check that large values list was truncated
158+ self .assertLessEqual (len (serialized ["large_values_list" ]["values" ]), 100 )
159+
160+ # Test deserialization
161+ deserialized = db ._deserialize_feature_stats (serialized )
162+
163+ # Verify structure is maintained
164+ self .assertIn ("normal_case" , deserialized )
165+ self .assertIn ("single_value" , deserialized )
166+ self .assertIn ("large_values_list" , deserialized )
167+ self .assertIn ("empty_values" , deserialized )
168+
169+ # Verify types are correct
170+ for dim , stats in deserialized .items ():
171+ self .assertIsInstance (stats ["min" ], float )
172+ self .assertIsInstance (stats ["max" ], float )
173+ self .assertIsInstance (stats ["values" ], list )
174+
175+ def test_feature_stats_preservation_during_load (self ):
176+ """Test that feature_stats ranges are preserved when loading from checkpoint"""
177+ # Create database with programs
178+ db1 = ProgramDatabase (self .config )
179+
180+ test_programs = []
181+
182+ for i in range (3 ):
183+ program = Program (
184+ id = f"stats_test_{ i } " ,
185+ code = f"# Stats test { i } " ,
186+ metrics = {
187+ "combined_score" : 0.2 + i * 0.3 ,
188+ "custom_metric1" : 20 + i * 30 ,
189+ "custom_metric2" : 200 + i * 100
190+ }
191+ )
192+ test_programs .append (program )
193+ db1 .add (program )
194+
195+ # Record original feature ranges
196+ original_ranges = {}
197+ for dim , stats in db1 .feature_stats .items ():
198+ original_ranges [dim ] = {
199+ "min" : stats ["min" ],
200+ "max" : stats ["max" ]
201+ }
202+
203+ # Save checkpoint
204+ db1 .save (self .test_dir , iteration = 50 )
205+
206+ # Load into new database
207+ db2 = ProgramDatabase (self .config )
208+ db2 .load (self .test_dir )
209+
210+ # Verify feature ranges are preserved
211+ for dim , original_range in original_ranges .items ():
212+ self .assertIn (dim , db2 .feature_stats )
213+ loaded_stats = db2 .feature_stats [dim ]
214+
215+ self .assertAlmostEqual (
216+ loaded_stats ["min" ], original_range ["min" ], places = 5 ,
217+ msg = f"Min value changed for { dim } : { original_range ['min' ]} -> { loaded_stats ['min' ]} "
218+ )
219+ self .assertAlmostEqual (
220+ loaded_stats ["max" ], original_range ["max" ], places = 5 ,
221+ msg = f"Max value changed for { dim } : { original_range ['max' ]} -> { loaded_stats ['max' ]} "
222+ )
223+
224+ # Test that adding a new program within existing ranges doesn't break anything
225+ new_program = Program (
226+ id = "range_test" ,
227+ code = "# Program to test range stability" ,
228+ metrics = {
229+ "combined_score" : 0.35 , # Within existing range
230+ "custom_metric1" : 35 , # Within existing range
231+ "custom_metric2" : 250 # Within existing range
232+ }
233+ )
234+
235+ # Adding this program should not cause issues
236+ db2 .add (new_program )
237+ new_coords = db2 ._calculate_feature_coords (new_program )
238+
239+ # Should get valid coordinates
240+ self .assertEqual (len (new_coords ), len (self .config .feature_dimensions ))
241+ for coord in new_coords :
242+ self .assertIsInstance (coord , int )
243+ self .assertGreaterEqual (coord , 0 )
244+
245+ def test_feature_stats_with_numpy_types (self ):
246+ """Test that numpy types are correctly handled in serialization"""
247+ import numpy as np
248+
249+ db = ProgramDatabase (self .config )
250+
251+ # Simulate feature_stats with numpy types
252+ db .feature_stats = {
253+ "numpy_test" : {
254+ "min" : np .float64 (1.5 ),
255+ "max" : np .float64 (9.5 ),
256+ "values" : [np .float64 (x ) for x in [1.5 , 5.0 , 9.5 ]]
257+ }
258+ }
259+
260+ # Test serialization doesn't fail
261+ serialized = db ._serialize_feature_stats ()
262+
263+ # Verify numpy types were converted to Python types
264+ self .assertIsInstance (serialized ["numpy_test" ]["min" ], float )
265+ self .assertIsInstance (serialized ["numpy_test" ]["max" ], float )
266+
267+ # Test deserialization
268+ deserialized = db ._deserialize_feature_stats (serialized )
269+ self .assertIsInstance (deserialized ["numpy_test" ]["min" ], float )
270+ self .assertIsInstance (deserialized ["numpy_test" ]["max" ], float )
271+
272+ def test_malformed_feature_stats_handling (self ):
273+ """Test handling of malformed feature_stats during deserialization"""
274+ db = ProgramDatabase (self .config )
275+
276+ # Test with malformed data
277+ malformed_data = {
278+ "valid_entry" : {
279+ "min" : 1.0 ,
280+ "max" : 10.0 ,
281+ "values" : [1.0 , 5.0 , 10.0 ]
282+ },
283+ "invalid_entry" : "this is not a dict" ,
284+ "missing_keys" : {
285+ "min" : 1.0
286+ # missing "max" and "values"
287+ }
288+ }
289+
290+ with patch ('openevolve.database.logger' ) as mock_logger :
291+ deserialized = db ._deserialize_feature_stats (malformed_data )
292+
293+ # Should have valid entry and skip invalid ones
294+ self .assertIn ("valid_entry" , deserialized )
295+ self .assertNotIn ("invalid_entry" , deserialized )
296+ self .assertIn ("missing_keys" , deserialized ) # Should be created with defaults
297+
298+ # Should have logged warning for invalid entry
299+ mock_logger .warning .assert_called ()
300+
301+
302+ if __name__ == "__main__" :
303+ unittest .main ()
0 commit comments