Skip to content

Commit 352ee33

Browse files
committed
update controller
1 parent 92c7f7c commit 352ee33

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

openevolve/controller.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class OpenEvolve:
7272

7373
def __init__(
7474
self,
75-
initial_program_path: str,
75+
initial_programs_paths: List[str],
7676
evaluation_file: str,
7777
config_path: Optional[str] = None,
7878
config: Optional[Config] = None,
@@ -86,9 +86,15 @@ def __init__(
8686
# Load from file or use defaults
8787
self.config = load_config(config_path)
8888

89-
# Set up output directory
89+
# Assert that initial_programs_paths is a list, and not empty
90+
if not initial_programs_paths:
91+
raise ValueError("initial_programs_paths must be a non-empty list of file paths")
92+
93+
# Set up output directory.
94+
# If output_dir is specified, use it
95+
# Otherwise, if initial_programs_paths has a single path, use the directory of the initial program.
9096
self.output_dir = output_dir or os.path.join(
91-
os.path.dirname(initial_program_path), "openevolve_output"
97+
os.path.dirname(initial_programs_paths[0]), "openevolve_output"
9298
)
9399
os.makedirs(self.output_dir, exist_ok=True)
94100

@@ -122,20 +128,31 @@ def __init__(
122128
logger.debug(f"Generated LLM seed: {llm_seed}")
123129

124130
# Load initial program
125-
self.initial_program_path = initial_program_path
126-
self.initial_program_code = self._load_initial_program()
131+
self.initial_programs_paths = initial_programs_paths
132+
self.initial_programs_code = self._load_initial_programs()
133+
134+
# Assume all initial programs are in the same language
127135
if not self.config.language:
128-
self.config.language = extract_code_language(self.initial_program_code)
136+
self.config.language = extract_code_language(self.initial_programs_code[0])
129137

130138
# Extract file extension from initial program
131-
self.file_extension = os.path.splitext(initial_program_path)[1]
139+
self.file_extension = os.path.splitext(initial_programs_paths[0])[1]
132140
if not self.file_extension:
133141
# Default to .py if no extension found
134142
self.file_extension = ".py"
135143
else:
136144
# Make sure it starts with a dot
137145
if not self.file_extension.startswith("."):
138146
self.file_extension = f".{self.file_extension}"
147+
148+
# Check that all files have the same extension
149+
for path in initial_programs_paths[1:]:
150+
ext = os.path.splitext(path)[1]
151+
if ext != self.file_extension:
152+
raise ValueError(
153+
f"All initial program files must have the same extension. "
154+
f"Expected {self.file_extension}, but got {ext} for {path}"
155+
)
139156

140157
# Initialize components
141158
self.llm_ensemble = LLMEnsemble(self.config.llm.models)
@@ -160,7 +177,7 @@ def __init__(
160177
)
161178
self.evaluation_file = evaluation_file
162179

163-
logger.info(f"Initialized OpenEvolve with {initial_program_path}")
180+
logger.info(f"Initialized OpenEvolve with {initial_programs_paths}")
164181

165182
# Initialize improved parallel processing components
166183
self.parallel_controller = None
@@ -189,10 +206,13 @@ def _setup_logging(self) -> None:
189206

190207
logger.info(f"Logging to {log_file}")
191208

192-
def _load_initial_program(self) -> str:
193-
"""Load the initial program from file"""
194-
with open(self.initial_program_path, "r") as f:
195-
return f.read()
209+
def _load_initial_programs(self) -> str:
210+
"""Load the initial programs from file"""
211+
programs = []
212+
for path in self.initial_programs_paths:
213+
with open(path, "r") as f:
214+
programs.append(f.read())
215+
return programs
196216

197217
async def run(
198218
self,
@@ -226,29 +246,28 @@ async def run(
226246
should_add_initial = (
227247
start_iteration == 0
228248
and len(self.database.programs) == 0
229-
and not any(
230-
p.code == self.initial_program_code for p in self.database.programs.values()
231-
)
232249
)
233250

234251
if should_add_initial:
235-
logger.info("Adding initial program to database")
236-
initial_program_id = str(uuid.uuid4())
252+
logger.info("Adding initial programs to database")
253+
for code in self.initial_programs_code:
254+
initial_program_id = str(uuid.uuid4())
237255

238-
# Evaluate the initial program
239-
initial_metrics = await self.evaluator.evaluate_program(
240-
self.initial_program_code, initial_program_id
241-
)
256+
# Evaluate the initial program
257+
initial_metrics = await self.evaluator.evaluate_program(
258+
code, initial_program_id
259+
)
242260

243-
initial_program = Program(
244-
id=initial_program_id,
245-
code=self.initial_program_code,
246-
language=self.config.language,
247-
metrics=initial_metrics,
248-
iteration_found=start_iteration,
249-
)
261+
initial_program = Program(
262+
id=initial_program_id,
263+
code=code,
264+
language=self.config.language,
265+
metrics=initial_metrics,
266+
iteration_found=start_iteration,
267+
)
250268

251-
self.database.add(initial_program)
269+
# TODO. Should the island be incremented and reset here?
270+
self.database.add(initial_program)
252271
else:
253272
logger.info(
254273
f"Skipping initial program addition (resuming from iteration {start_iteration} "

0 commit comments

Comments
 (0)