|
6 | 6 | import os |
7 | 7 | import tempfile |
8 | 8 | import unittest |
9 | | -from unittest.mock import AsyncMock, MagicMock, patch |
| 9 | +from unittest.mock import AsyncMock, MagicMock, Mock, patch |
10 | 10 | import json |
11 | 11 | import time |
12 | 12 |
|
@@ -96,10 +96,16 @@ async def run_test(): |
96 | 96 | self.assertEqual(len(controller.database.programs), 0) |
97 | 97 | self.assertEqual(controller.database.last_iteration, 0) |
98 | 98 |
|
99 | | - # Mock the LLM to avoid actual API calls |
100 | | - with patch.object(controller.llm_ensemble, "generate_with_context") as mock_llm: |
101 | | - mock_llm.return_value = "No changes needed" |
102 | | - |
| 99 | + # Mock the parallel controller to avoid API calls |
| 100 | + with patch("openevolve.controller.ImprovedParallelController") as mock_controller_class: |
| 101 | + mock_controller = Mock() |
| 102 | + mock_controller.run_evolution = AsyncMock(return_value=None) |
| 103 | + mock_controller.start = Mock(return_value=None) |
| 104 | + mock_controller.stop = Mock(return_value=None) |
| 105 | + mock_controller.shutdown_flag = Mock() |
| 106 | + mock_controller.shutdown_flag.is_set.return_value = False |
| 107 | + mock_controller_class.return_value = mock_controller |
| 108 | + |
103 | 109 | # Run for 0 iterations (just initialization) |
104 | 110 | result = await controller.run(iterations=0) |
105 | 111 |
|
@@ -144,10 +150,16 @@ async def run_test(): |
144 | 150 |
|
145 | 151 | controller.database.add(existing_program) |
146 | 152 |
|
147 | | - # Mock the LLM to avoid actual API calls |
148 | | - with patch.object(controller.llm_ensemble, "generate_with_context") as mock_llm: |
149 | | - mock_llm.return_value = "No changes needed" |
150 | | - |
| 153 | + # Mock the parallel controller to avoid API calls |
| 154 | + with patch("openevolve.controller.ImprovedParallelController") as mock_controller_class: |
| 155 | + mock_controller = Mock() |
| 156 | + mock_controller.run_evolution = AsyncMock(return_value=None) |
| 157 | + mock_controller.start = Mock(return_value=None) |
| 158 | + mock_controller.stop = Mock(return_value=None) |
| 159 | + mock_controller.shutdown_flag = Mock() |
| 160 | + mock_controller.shutdown_flag.is_set.return_value = False |
| 161 | + mock_controller_class.return_value = mock_controller |
| 162 | + |
151 | 163 | # Run for 0 iterations (just initialization) |
152 | 164 | result = await controller.run(iterations=0) |
153 | 165 |
|
@@ -191,10 +203,16 @@ async def run_test(): |
191 | 203 | self.assertEqual(len(controller.database.programs), 1) |
192 | 204 | self.assertEqual(controller.database.last_iteration, 10) |
193 | 205 |
|
194 | | - # Mock the LLM to avoid actual API calls |
195 | | - with patch.object(controller.llm_ensemble, "generate_with_context") as mock_llm: |
196 | | - mock_llm.return_value = "No changes needed" |
197 | | - |
| 206 | + # Mock the parallel controller to avoid API calls |
| 207 | + with patch("openevolve.controller.ImprovedParallelController") as mock_controller_class: |
| 208 | + mock_controller = Mock() |
| 209 | + mock_controller.run_evolution = AsyncMock(return_value=None) |
| 210 | + mock_controller.start = Mock(return_value=None) |
| 211 | + mock_controller.stop = Mock(return_value=None) |
| 212 | + mock_controller.shutdown_flag = Mock() |
| 213 | + mock_controller.shutdown_flag.is_set.return_value = False |
| 214 | + mock_controller_class.return_value = mock_controller |
| 215 | + |
198 | 216 | # Run for 0 iterations (just initialization) |
199 | 217 | result = await controller.run(iterations=0) |
200 | 218 |
|
@@ -241,10 +259,16 @@ async def run_test(): |
241 | 259 | self.assertEqual(len(controller.database.programs), 1) |
242 | 260 | self.assertEqual(controller.database.last_iteration, 0) |
243 | 261 |
|
244 | | - # Mock the LLM to avoid actual API calls |
245 | | - with patch.object(controller.llm_ensemble, "generate_with_context") as mock_llm: |
246 | | - mock_llm.return_value = "No changes needed" |
247 | | - |
| 262 | + # Mock the parallel controller to avoid API calls |
| 263 | + with patch("openevolve.controller.ImprovedParallelController") as mock_controller_class: |
| 264 | + mock_controller = Mock() |
| 265 | + mock_controller.run_evolution = AsyncMock(return_value=None) |
| 266 | + mock_controller.start = Mock(return_value=None) |
| 267 | + mock_controller.stop = Mock(return_value=None) |
| 268 | + mock_controller.shutdown_flag = Mock() |
| 269 | + mock_controller.shutdown_flag.is_set.return_value = False |
| 270 | + mock_controller_class.return_value = mock_controller |
| 271 | + |
248 | 272 | # Run for 0 iterations (just initialization) |
249 | 273 | result = await controller.run(iterations=0) |
250 | 274 |
|
@@ -275,9 +299,9 @@ async def run_test(): |
275 | 299 | output_dir=self.test_dir, |
276 | 300 | ) |
277 | 301 |
|
278 | | - # Mock the LLM to avoid actual API calls |
279 | | - with patch.object(controller.llm_ensemble, "generate_with_context") as mock_llm: |
280 | | - mock_llm.return_value = "No changes needed" |
| 302 | + # Mock the parallel controller to avoid API calls |
| 303 | + with patch.object(controller, "parallel_controller") as mock_parallel: |
| 304 | + mock_parallel.run_evolution = AsyncMock(return_value=None) |
281 | 305 |
|
282 | 306 | # Run first time |
283 | 307 | result1 = await controller.run(iterations=0) |
|
0 commit comments