Skip to content

Commit 23ca50f

Browse files
committed
Update task_adapter.py
1 parent 624a640 commit 23ca50f

File tree

1 file changed

+68
-20
lines changed

1 file changed

+68
-20
lines changed

examples/algotune/task_adapter.py

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,16 @@ def _harmonize_solve_and_is_solution(self, solve_method: str, is_solution_method
290290
harmonized_is_solution
291291
)
292292

293+
# Fix fft_convolution to return list format instead of numpy array
294+
if task_name == 'fft_convolution':
295+
# Ensure the solve method returns list format
296+
if 'convolution_result = signal.fftconvolve' in solve_method:
297+
# Replace the solution return to ensure list format
298+
solve_method = solve_method.replace(
299+
'solution = {"convolution": convolution_result}',
300+
'solution = {"convolution": convolution_result.tolist()}'
301+
)
302+
293303
# Add similar fixes for other common patterns
294304
# Handle empty array checks
295305
if 'if proposed_list == []' in harmonized_is_solution:
@@ -823,20 +833,32 @@ def evaluate(program_path, config=None):
823833
try:
824834
# Load configuration
825835
if config is None:
826-
config = {{
827-
"algotune": {{
828-
"num_trials": 5,
829-
"data_size": 100,
830-
"timeout": 3000,
831-
"num_runs": 3,
832-
"warmup_runs": 1
836+
# Try to load config from YAML file first
837+
try:
838+
import yaml
839+
from pathlib import Path
840+
config_path = Path(__file__).parent / "config.yaml"
841+
if config_path.exists():
842+
with open(config_path, 'r') as f:
843+
config = yaml.safe_load(f)
844+
else:
845+
raise FileNotFoundError("config.yaml not found")
846+
except Exception as e:
847+
# Could not load config.yaml, using defaults
848+
config = {{
849+
"algotune": {{
850+
"num_trials": 5,
851+
"data_size": 100,
852+
"timeout": 300,
853+
"num_runs": 3,
854+
"warmup_runs": 1
855+
}}
833856
}}
834-
}}
835857
836858
# Extract AlgoTune task-specific settings from config
837859
algotune_config = config.get("algotune", {{}})
838860
num_trials = algotune_config.get("num_trials", 5)
839-
data_size = algotune_config.get("data_size", 5)
861+
data_size = algotune_config.get("data_size", 100)
840862
timeout_seconds = algotune_config.get("timeout", 300)
841863
num_runs = algotune_config.get("num_runs", 3)
842864
warmup_runs = algotune_config.get("warmup_runs", 1)
@@ -1039,16 +1061,28 @@ def evaluate_stage1(program_path, config=None):
10391061
try:
10401062
# Load configuration
10411063
if config is None:
1042-
config = {{
1043-
"algotune": {{
1044-
"num_trials": 5,
1045-
"data_size": 100,
1046-
"timeout": 300
1064+
# Try to load config from YAML file first
1065+
try:
1066+
import yaml
1067+
from pathlib import Path
1068+
config_path = Path(__file__).parent / "config.yaml"
1069+
if config_path.exists():
1070+
with open(config_path, 'r') as f:
1071+
config = yaml.safe_load(f)
1072+
else:
1073+
raise FileNotFoundError("config.yaml not found")
1074+
except Exception as e:
1075+
# Could not load config.yaml, using defaults
1076+
config = {{
1077+
"algotune": {{
1078+
"num_trials": 5,
1079+
"data_size": 100,
1080+
"timeout": 300
1081+
}}
10471082
}}
1048-
}}
10491083
10501084
algotune_config = config.get("algotune", {{}})
1051-
data_size = algotune_config.get("data_size", 5)
1085+
data_size = algotune_config.get("data_size", 100)
10521086
timeout_seconds = algotune_config.get("timeout", 300)
10531087
10541088
# Load the program
@@ -1184,6 +1218,10 @@ def replace_latex_command(match):
11841218
" You will receive better scores the quicker your solution runs, and you will be penalized for exceeding the time limit or returning non-optimal solutions.\n\n"
11851219
" Below you find the description of the task you will have to solve. Read it carefully and understand what the problem is and what your solver should do.\n\n"
11861220
)
1221+
1222+
# Properly indent the description for YAML block scalar
1223+
indented_description = '\n'.join(' ' + line if line.strip() else ''
1224+
for line in clean_description.split('\n'))
11871225
config = f'''# Configuration for {task_name} task - Optimized Gemini Flash 2.5
11881226
# Achieved 1.64x AlgoTune Score with these settings
11891227
@@ -1210,10 +1248,10 @@ def replace_latex_command(match):
12101248
# Prompt Configuration - Optimal settings
12111249
prompt:
12121250
system_message: |
1213-
{system_prompt}You are an expert programmer specializing in {category} algorithms. Your task is to improve the {task_name} algorithm implementation with baseline comparison.
1251+
{system_prompt} You are an expert programmer specializing in {category} algorithms. Your task is to improve the {task_name} algorithm implementation with baseline comparison.
12141252
12151253
The problem description is:
1216-
{clean_description}
1254+
{indented_description}
12171255
12181256
Focus on improving the solve method to correctly handle the input format and produce valid solutions efficiently. Your solution will be compared against the reference AlgoTune baseline implementation to measure speedup and correctness.
12191257
num_top_programs: 3 # Best balance
@@ -1253,14 +1291,24 @@ def replace_latex_command(match):
12531291
# AlgoTune task-specific configuration
12541292
algotune:
12551293
num_trials: 5
1256-
data_size: 100
1294+
data_size: {self._get_task_data_size(task_name)}
12571295
timeout: 300
12581296
num_runs: 3
12591297
warmup_runs: 1
12601298
'''
12611299

12621300
return config
12631301

1302+
def _get_task_data_size(self, task_name: str) -> int:
1303+
"""Get task-specific data_size values."""
1304+
# Task-specific overrides for computational intensity
1305+
if task_name == "convolve2d_full_fill":
1306+
return 1 # Very computationally intensive due to 30*n × 30*n and 8*n × 8*n matrices
1307+
elif task_name == "fft_convolution":
1308+
return 10 # Moderate computational intensity
1309+
else:
1310+
return 100 # Default for all other tasks
1311+
12641312
def _generate_task_specific_method(self, task_name: str, solve_method: str, class_info: Dict[str, Any]) -> str:
12651313
"""Generate a generic fallback method when the actual solve method cannot be extracted."""
12661314

@@ -1387,4 +1435,4 @@ def get_task_info(self, task_name: str) -> Dict[str, Any]:
13871435
'description': self.get_task_description(task_name),
13881436
'path': str(self.available_tasks[task_name]['path']),
13891437
'available': True
1390-
}
1438+
}

0 commit comments

Comments
 (0)