Skip to content

Commit 9b288fd

Browse files
committed
Update task_adapter.py
1 parent 0f59155 commit 9b288fd

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

examples/algotune/task_adapter.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,81 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
243243

244244
return class_info
245245

246+
def _harmonize_solve_and_is_solution(self, solve_method: str, is_solution_method: str, task_name: str) -> tuple:
247+
"""
248+
Harmonize the formats between solve() and is_solution() methods.
249+
Fixes common mismatches like returning numpy arrays vs expecting lists.
250+
251+
Args:
252+
solve_method: The extracted solve method code
253+
is_solution_method: The extracted is_solution method code
254+
task_name: Name of the task for specific fixes
255+
256+
Returns:
257+
Tuple of (harmonized_solve_method, harmonized_is_solution_method)
258+
"""
259+
import re
260+
261+
# Fix common type checking issues in is_solution
262+
harmonized_is_solution = is_solution_method
263+
264+
# Replace strict list checking with flexible array/list checking
265+
if 'isinstance(proposed_list, list)' in harmonized_is_solution:
266+
harmonized_is_solution = harmonized_is_solution.replace(
267+
'isinstance(proposed_list, list)',
268+
'isinstance(proposed_list, (list, np.ndarray))'
269+
)
270+
271+
# Fix error messages to reflect the change
272+
harmonized_is_solution = harmonized_is_solution.replace(
273+
"'transformed_image' is not a list.",
274+
"'transformed_image' is not a list or array."
275+
)
276+
277+
# Add conversion logic for arrays to lists where needed
278+
if 'transformed_image' in harmonized_is_solution and task_name == 'affine_transform_2d':
279+
# For affine_transform_2d, convert arrays to lists for validation
280+
conversion_code = '''
281+
# Convert numpy array to list if needed for validation
282+
if isinstance(proposed_list, np.ndarray):
283+
proposed_list = proposed_list.tolist()
284+
'''
285+
# Insert conversion code after extracting proposed_list
286+
pattern = r'(proposed_list = solution\["transformed_image"\])'
287+
harmonized_is_solution = re.sub(
288+
pattern,
289+
r'\1' + conversion_code,
290+
harmonized_is_solution
291+
)
292+
293+
# Add similar fixes for other common patterns
294+
# Handle empty array checks
295+
if 'if proposed_list == []' in harmonized_is_solution:
296+
harmonized_is_solution = harmonized_is_solution.replace(
297+
'if proposed_list == []',
298+
'if (isinstance(proposed_list, list) and proposed_list == []) or (isinstance(proposed_list, np.ndarray) and proposed_list.size == 0)'
299+
)
300+
301+
# Fix numpy array shape mismatch issues
302+
if 'operands could not be broadcast' in task_name or task_name == 'affine_transform_2d':
303+
# Add proper array handling
304+
array_handling = '''
305+
# Ensure arrays are properly formatted
306+
if isinstance(proposed_list, np.ndarray):
307+
if proposed_list.size == 0:
308+
proposed_list = []
309+
else:
310+
proposed_list = proposed_list.tolist()
311+
'''
312+
# Insert after variable extraction
313+
if 'proposed_list = solution["transformed_image"]' in harmonized_is_solution:
314+
harmonized_is_solution = harmonized_is_solution.replace(
315+
'proposed_list = solution["transformed_image"]',
316+
'proposed_list = solution["transformed_image"]' + array_handling
317+
)
318+
319+
return solve_method, harmonized_is_solution
320+
246321
def _clean_init_method(self, init_method: str) -> str:
247322
"""
248323
Clean up extracted __init__ method body by removing docstrings and super() calls.
@@ -364,6 +439,12 @@ def _generate_initial_program(self, task_name: str) -> str:
364439
# This should be replaced with actual validation logic
365440
return True'''
366441

442+
# Harmonize solve and is_solution methods to fix format mismatches
443+
if solve_method and is_solution_method:
444+
method_body, is_solution_method_body = self._harmonize_solve_and_is_solution(
445+
method_body, is_solution_method_body, task_name
446+
)
447+
367448
# Clean the description for use in docstring
368449
import re
369450
docstring_description = description.replace('\\', '\\\\')

0 commit comments

Comments
 (0)