Skip to content

Commit e1e7fbc

Browse files
committed
Update task_adapter.py
1 parent 43e806c commit e1e7fbc

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

examples/algotune/task_adapter.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
118118
class_info = {
119119
'name': None,
120120
'solve_method': None,
121+
'init_method': None,
121122
'generate_problem_method': None,
122123
'is_solution_method': None,
123124
'imports': [],
@@ -171,7 +172,7 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
171172

172173
# Find the solve method using AST
173174
for item in node.body:
174-
if isinstance(item, ast.FunctionDef) and item.name == 'solve':
175+
if isinstance(item, ast.FunctionDef) and item.name in ['solve', '__init__']:
175176
try:
176177
# Get the source lines for this method
177178
method_start = item.lineno - 1 # Convert to 0-based index
@@ -219,13 +220,20 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
219220
fixed_lines.append('')
220221
body_lines = fixed_lines
221222
if body_lines:
222-
class_info['solve_method'] = '\n'.join(body_lines)
223+
if item.name == 'solve':
224+
class_info['solve_method'] = '\n'.join(body_lines)
225+
elif item.name == '__init__':
226+
class_info['init_method'] = '\n'.join(body_lines)
223227
else:
224-
class_info['solve_method'] = ' # Placeholder for solve method\n pass'
228+
if item.name == 'solve':
229+
class_info['solve_method'] = ' # Placeholder for solve method\n pass'
230+
elif item.name == '__init__':
231+
class_info['init_method'] = ' # Placeholder for __init__ method\n pass'
225232
except Exception as e:
226-
class_info['solve_method'] = ' # Placeholder for solve method\n pass'
227-
break
228-
break
233+
if item.name == 'solve':
234+
class_info['solve_method'] = ' # Placeholder for solve method\n pass'
235+
elif item.name == '__init__':
236+
class_info['init_method'] = ' # Placeholder for __init__ method\n pass'
229237

230238
return class_info
231239

@@ -287,6 +295,15 @@ def _generate_initial_program(self, task_name: str) -> str:
287295
# Fallback to task-specific method if extraction failed
288296
method_body = self._generate_task_specific_method(task_name, solve_method, class_info)
289297

298+
# Use the actual __init__ method from the original task
299+
init_method = class_info['init_method']
300+
if init_method:
301+
# The method body is already properly indented from extraction
302+
init_method_body = init_method
303+
else:
304+
# Fallback to simple pass if extraction failed
305+
init_method_body = ' pass'
306+
290307
# Clean the description for use in docstring
291308
import re
292309
docstring_description = description.replace('\\', '\\\\')
@@ -320,7 +337,7 @@ class {class_info['name']}:
320337
321338
def __init__(self):
322339
"""Initialize the {class_info['name']}."""
323-
pass
340+
{init_method_body}
324341
325342
def solve(self, problem):
326343
"""

0 commit comments

Comments
 (0)