1+ """
2+ Tests for iteration counting and checkpoint behavior
3+ """
4+
5+ import asyncio
6+ import os
7+ import tempfile
8+ import unittest
9+ from unittest .mock import Mock , patch , MagicMock
10+
11+ # Set dummy API key for testing
12+ os .environ ["OPENAI_API_KEY" ] = "test"
13+
14+ from openevolve .config import Config
15+ from openevolve .controller import OpenEvolve
16+ from openevolve .database import Program , ProgramDatabase
17+
18+
class TestIterationCounting(unittest.IsolatedAsyncioTestCase):
    """Tests for correct iteration counting and checkpoint behavior.

    The class derives from ``IsolatedAsyncioTestCase`` (itself a ``TestCase``
    subclass) so that the coroutine test
    ``test_controller_iteration_behavior`` is actually awaited by the test
    runner. On a plain ``TestCase`` the runner calls the method but never
    awaits the returned coroutine, so none of its assertions execute.

    The three synchronous tests do not drive the controller; they re-apply
    the start/stop arithmetic locally (described in the original comments as
    "the logic from controller.py") through the two private helpers below.
    """

    @staticmethod
    def _evolution_start(start_iteration: int, should_add_initial: bool) -> int:
        """Return the first evolution iteration.

        On a fresh start (``should_add_initial`` is true and
        ``start_iteration`` is 0) iteration 0 is treated as special and
        evolution begins at 1; otherwise ``start_iteration`` is returned
        unchanged.
        """
        if should_add_initial and start_iteration == 0:
            return 1
        return start_iteration

    @staticmethod
    def _checkpoint_iterations(first: int, count: int, interval: int) -> list:
        """Return the iterations in ``range(first, first + count)`` that are
        positive multiples of ``interval``, in ascending order."""
        return [
            i
            for i in range(first, first + count)
            if i > 0 and i % interval == 0
        ]

    def setUp(self):
        """Create a temp directory holding a test program and an evaluator."""
        self.test_dir = tempfile.mkdtemp()

        # Create test program
        self.program_content = """# EVOLVE-BLOCK-START
def compute(x):
    return x * 2
# EVOLVE-BLOCK-END
"""
        self.program_file = os.path.join(self.test_dir, "test_program.py")
        with open(self.program_file, "w") as f:
            f.write(self.program_content)

        # Create test evaluator
        self.eval_content = """
def evaluate(program_path):
    return {"score": 0.5, "performance": 0.6}
"""
        self.eval_file = os.path.join(self.test_dir, "evaluator.py")
        with open(self.eval_file, "w") as f:
            f.write(self.eval_content)

    def tearDown(self):
        """Remove the temp directory created in ``setUp`` (best effort)."""
        import shutil

        shutil.rmtree(self.test_dir, ignore_errors=True)

    def test_fresh_start_iteration_counting(self):
        """Test that fresh start correctly handles iteration 0 as special"""
        # Test the logic without actually running evolution
        config = Config()
        config.max_iterations = 20
        config.checkpoint_interval = 10

        # Simulate fresh start
        evolution_start = self._evolution_start(0, should_add_initial=True)
        evolution_iterations = config.max_iterations

        self.assertEqual(evolution_start, 1, "Evolution should start at iteration 1")
        self.assertEqual(evolution_iterations, 20, "Should run 20 evolution iterations")

        # Exclusive upper bound of the iteration range that would be processed
        total_iterations = evolution_start + evolution_iterations
        self.assertEqual(total_iterations, 21, "Total range should be 21 (1 through 20)")

        # Check checkpoint alignment
        expected_checkpoints = self._checkpoint_iterations(
            evolution_start, evolution_iterations, config.checkpoint_interval
        )
        self.assertEqual(expected_checkpoints, [10, 20], "Checkpoints should be at 10 and 20")

    def test_resume_iteration_counting(self):
        """Test that resume correctly continues from checkpoint"""
        config = Config()
        config.max_iterations = 10
        config.checkpoint_interval = 10

        # Simulate resume from checkpoint 10: last iteration was 10, so the
        # next one is 11 and no initial program is added.
        evolution_start = self._evolution_start(11, should_add_initial=False)
        evolution_iterations = config.max_iterations

        self.assertEqual(evolution_start, 11, "Evolution should continue from iteration 11")
        self.assertEqual(evolution_iterations, 10, "Should run 10 more iterations")

        # Exclusive upper bound of the resumed iteration range
        total_iterations = evolution_start + evolution_iterations
        self.assertEqual(total_iterations, 21, "Should run through iteration 20")

        # Check checkpoint at 20
        expected_checkpoints = self._checkpoint_iterations(
            evolution_start, evolution_iterations, config.checkpoint_interval
        )
        self.assertEqual(expected_checkpoints, [20], "Should checkpoint at 20")

    def test_checkpoint_boundary_conditions(self):
        """Test checkpoint behavior at various boundaries"""
        test_cases = [
            # (start_iter, max_iter, checkpoint_interval, expected_checkpoints)
            (1, 100, 10, list(range(10, 101, 10))),  # Standard case
            (1, 99, 10, list(range(10, 100, 10))),  # Just short of last checkpoint
            (1, 101, 10, list(range(10, 101, 10))),  # Just past checkpoint
            (0, 20, 5, [5, 10, 15, 20]),  # Special case with iteration 0
        ]

        for start, max_iter, interval, expected in test_cases:
            # subTest reports every failing case instead of stopping at the first
            with self.subTest(start=start, max_iter=max_iter, interval=interval):
                # Apply fresh start logic (a start of 0 is bumped to 1)
                evolution_start = self._evolution_start(start, should_add_initial=True)
                checkpoints = self._checkpoint_iterations(
                    evolution_start, max_iter, interval
                )
                self.assertEqual(
                    checkpoints,
                    expected,
                    f"Failed for start={start}, max={max_iter}, interval={interval}",
                )

    async def test_controller_iteration_behavior(self):
        """Test actual controller behavior with iteration counting"""
        config = Config()
        config.max_iterations = 20
        config.checkpoint_interval = 10
        config.database.in_memory = True
        config.evaluator.parallel_evaluations = 1

        controller = OpenEvolve(
            initial_program_path=self.program_file,
            evaluation_file=self.eval_file,
            config=config,
            output_dir=self.test_dir,
        )

        # Track checkpoint calls while still delegating to the real method
        checkpoint_calls = []
        original_save = controller._save_checkpoint

        def recording_save(iteration):
            checkpoint_calls.append(iteration)
            return original_save(iteration)

        controller._save_checkpoint = recording_save

        # Mock LLM so no real API request is made
        with patch("openevolve.llm.ensemble.LLMEnsemble.generate_with_context") as mock_llm:
            mock_llm.return_value = '''```python
# EVOLVE-BLOCK-START
def compute(x):
    return x << 1
# EVOLVE-BLOCK-END
```'''

            # Run with limited iterations to test
            await controller.run(iterations=20)

        # Verify checkpoints were called correctly
        # Note: We expect checkpoints at 10 and 20
        self.assertIn(10, checkpoint_calls, "Should checkpoint at iteration 10")
        self.assertIn(20, checkpoint_calls, "Should checkpoint at iteration 20")

        # Verify we have the right number of programs (initial + 20 evolution)
        # This may vary due to parallel execution, but should be at least 21
        self.assertGreaterEqual(
            len(controller.database.programs),
            21,
            "Should have at least 21 programs (initial + 20 iterations)",
        )
191+
if __name__ == "__main__":
    # Script entry point. Synchronous tests run through a normal unittest
    # suite; coroutine tests are driven explicitly with asyncio.run so they
    # are really awaited regardless of which TestCase base class is in use.
    # The process exit status reflects the combined outcome.
    import inspect
    import sys

    loader = unittest.TestLoader()
    sync_names = []
    async_names = []
    for name in loader.getTestCaseNames(TestIterationCounting):
        method = getattr(TestIterationCounting, name)
        if inspect.iscoroutinefunction(method):
            async_names.append(name)
        else:
            sync_names.append(name)

    suite = unittest.TestSuite(TestIterationCounting(name) for name in sync_names)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Run the async tests separately
    async def run_async_test(name):
        """Await one coroutine test with setUp/tearDown; return True on success."""
        test = TestIterationCounting(name)
        test.setUp()
        try:
            await getattr(test, name)()
        except Exception as e:  # top-level boundary: report and record failure
            print(f"✗ {name} failed: {e}")
            return False
        else:
            print(f"✓ {name} passed")
            return True
        finally:
            test.tearDown()

    # A list (not a generator) so every async test runs even after a failure
    async_results = [asyncio.run(run_async_test(name)) for name in async_names]

    sys.exit(0 if result.wasSuccessful() and all(async_results) else 1)
0 commit comments