Skip to content

Commit f6a715b

Browse files
committed
Update task_adapter.py
1 parent 3f2a048 commit f6a715b

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

examples/algotune/task_adapter.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,9 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
170170
# Extract the entire class code
171171
class_info['class_code'] = ast.unparse(node)
172172

173-
# Find the solve method using AST
173+
# Find the solve, __init__, and is_solution methods using AST
174174
for item in node.body:
175-
if isinstance(item, ast.FunctionDef) and item.name in ['solve', '__init__']:
175+
if isinstance(item, ast.FunctionDef) and item.name in ['solve', '__init__', 'is_solution']:
176176
try:
177177
# Get the source lines for this method
178178
method_start = item.lineno - 1 # Convert to 0-based index
@@ -224,16 +224,22 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
224224
class_info['solve_method'] = '\n'.join(body_lines)
225225
elif item.name == '__init__':
226226
class_info['init_method'] = '\n'.join(body_lines)
227+
elif item.name == 'is_solution':
228+
class_info['is_solution_method'] = '\n'.join(body_lines)
227229
else:
228230
if item.name == 'solve':
229231
class_info['solve_method'] = ' # Placeholder for solve method\n pass'
230232
elif item.name == '__init__':
231233
class_info['init_method'] = ' # Placeholder for __init__ method\n pass'
234+
elif item.name == 'is_solution':
235+
class_info['is_solution_method'] = ' # Placeholder for is_solution method\n pass'
232236
except Exception as e:
233237
if item.name == 'solve':
234238
class_info['solve_method'] = ' # Placeholder for solve method\n pass'
235239
elif item.name == '__init__':
236240
class_info['init_method'] = ' # Placeholder for __init__ method\n pass'
241+
elif item.name == 'is_solution':
242+
class_info['is_solution_method'] = ' # Placeholder for is_solution method\n pass'
237243

238244
return class_info
239245

@@ -346,6 +352,18 @@ def _generate_initial_program(self, task_name: str) -> str:
346352
# Fallback to simple pass if extraction failed
347353
init_method_body = ' pass'
348354

355+
# Use the actual is_solution method from the original task
356+
is_solution_method = class_info['is_solution_method']
357+
if is_solution_method:
358+
# The method body is already properly indented from extraction
359+
is_solution_method_body = is_solution_method
360+
else:
361+
# Fallback method if extraction failed
362+
is_solution_method_body = ''' """Check if the provided solution is valid."""
363+
# Placeholder validation - always returns True
364+
# This should be replaced with actual validation logic
365+
return True'''
366+
349367
# Clean the description for use in docstring
350368
import re
351369
docstring_description = description.replace('\\', '\\\\')
@@ -397,6 +415,24 @@ def solve(self, problem):
397415
except Exception as e:
398416
logging.error(f"Error in solve method: {{e}}")
399417
raise e
418+
419+
def is_solution(self, problem, solution):
420+
"""
421+
Check if the provided solution is valid.
422+
423+
Args:
424+
problem: The original problem
425+
solution: The proposed solution
426+
427+
Returns:
428+
True if the solution is valid, False otherwise
429+
"""
430+
try:
431+
{is_solution_method_body}
432+
433+
except Exception as e:
434+
logging.error(f"Error in is_solution method: {{e}}")
435+
return False
400436
401437
def run_solver(problem):
402438
"""
@@ -804,10 +840,13 @@ def evaluate(program_path, config=None):
804840
evolved_solution = safe_convert(evolved_solution)
805841
806842
try:
807-
is_valid = task_instance.is_solution(problem, evolved_solution)
843+
# Use the evolved program's own is_solution method for validation
844+
# This ensures consistency between the extracted solve and validation logic
845+
evolved_solver = program.{class_info['name']}()
846+
is_valid = evolved_solver.is_solution(problem, evolved_solution)
808847
correctness_score = 1.0 if is_valid else 0.0
809848
except Exception as e:
810-
print(f"Trial {{trial}}: Error checking solution validity: {{e}}")
849+
print(f"Trial {{trial}}: Error checking solution validity with evolved is_solution: {{e}}")
811850
correctness_score = 0.0
812851
is_valid = False
813852

0 commit comments

Comments
 (0)