@@ -243,6 +243,81 @@ def _extract_task_class_info(self, task_name: str) -> Dict[str, Any]:
243243
244244 return class_info
245245
246+ def _harmonize_solve_and_is_solution (self , solve_method : str , is_solution_method : str , task_name : str ) -> tuple :
247+ """
248+ Harmonize the formats between solve() and is_solution() methods.
249+ Fixes common mismatches like returning numpy arrays vs expecting lists.
250+
251+ Args:
252+ solve_method: The extracted solve method code
253+ is_solution_method: The extracted is_solution method code
254+ task_name: Name of the task for specific fixes
255+
256+ Returns:
257+ Tuple of (harmonized_solve_method, harmonized_is_solution_method)
258+ """
259+ import re
260+
261+ # Fix common type checking issues in is_solution
262+ harmonized_is_solution = is_solution_method
263+
264+ # Replace strict list checking with flexible array/list checking
265+ if 'isinstance(proposed_list, list)' in harmonized_is_solution :
266+ harmonized_is_solution = harmonized_is_solution .replace (
267+ 'isinstance(proposed_list, list)' ,
268+ 'isinstance(proposed_list, (list, np.ndarray))'
269+ )
270+
271+ # Fix error messages to reflect the change
272+ harmonized_is_solution = harmonized_is_solution .replace (
273+ "'transformed_image' is not a list." ,
274+ "'transformed_image' is not a list or array."
275+ )
276+
277+ # Add conversion logic for arrays to lists where needed
278+ if 'transformed_image' in harmonized_is_solution and task_name == 'affine_transform_2d' :
279+ # For affine_transform_2d, convert arrays to lists for validation
280+ conversion_code = '''
281+ # Convert numpy array to list if needed for validation
282+ if isinstance(proposed_list, np.ndarray):
283+ proposed_list = proposed_list.tolist()
284+ '''
285+ # Insert conversion code after extracting proposed_list
286+ pattern = r'(proposed_list = solution\["transformed_image"\])'
287+ harmonized_is_solution = re .sub (
288+ pattern ,
289+ r'\1' + conversion_code ,
290+ harmonized_is_solution
291+ )
292+
293+ # Add similar fixes for other common patterns
294+ # Handle empty array checks
295+ if 'if proposed_list == []' in harmonized_is_solution :
296+ harmonized_is_solution = harmonized_is_solution .replace (
297+ 'if proposed_list == []' ,
298+ 'if (isinstance(proposed_list, list) and proposed_list == []) or (isinstance(proposed_list, np.ndarray) and proposed_list.size == 0)'
299+ )
300+
301+ # Fix numpy array shape mismatch issues
302+ if 'operands could not be broadcast' in task_name or task_name == 'affine_transform_2d' :
303+ # Add proper array handling
304+ array_handling = '''
305+ # Ensure arrays are properly formatted
306+ if isinstance(proposed_list, np.ndarray):
307+ if proposed_list.size == 0:
308+ proposed_list = []
309+ else:
310+ proposed_list = proposed_list.tolist()
311+ '''
312+ # Insert after variable extraction
313+ if 'proposed_list = solution["transformed_image"]' in harmonized_is_solution :
314+ harmonized_is_solution = harmonized_is_solution .replace (
315+ 'proposed_list = solution["transformed_image"]' ,
316+ 'proposed_list = solution["transformed_image"]' + array_handling
317+ )
318+
319+ return solve_method , harmonized_is_solution
320+
246321 def _clean_init_method (self , init_method : str ) -> str :
247322 """
248323 Clean up extracted __init__ method body by removing docstrings and super() calls.
@@ -364,6 +439,12 @@ def _generate_initial_program(self, task_name: str) -> str:
364439 # This should be replaced with actual validation logic
365440 return True'''
366441
442+ # Harmonize solve and is_solution methods to fix format mismatches
443+ if solve_method and is_solution_method :
444+ method_body , is_solution_method_body = self ._harmonize_solve_and_is_solution (
445+ method_body , is_solution_method_body , task_name
446+ )
447+
367448 # Clean the description for use in docstring
368449 import re
369450 docstring_description = description .replace ('\\ ' , '\\ \\ ' )
0 commit comments