Skip to content

Commit 761c880

Browse files
committed
fixes
1 parent cfef4b6 commit 761c880

File tree

4 files changed

+225
-43
lines changed

4 files changed

+225
-43
lines changed

examples/matrix_multiplication/evaluate.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,15 @@ def evaluate_correctness(matrix_multiply) -> float:
137137
Returns:
138138
Correctness score (0.0 to 1.0)
139139
"""
140-
# Define test cases
140+
# Define test cases focused on smaller matrices (as in the paper)
141141
test_sizes = [
142142
(2, 2, 2),
143+
(2, 3, 2),
143144
(3, 3, 3),
144-
(4, 4, 4),
145-
(10, 10, 10),
146145
(3, 4, 5),
147-
(7, 3, 8),
146+
(4, 4, 4),
147+
(4, 5, 3),
148+
(5, 5, 5),
148149
]
149150

150151
passed = 0
@@ -181,26 +182,41 @@ def evaluate_performance(matrix_multiply) -> float:
181182
Returns:
182183
Performance score (0.0 to 1.0)
183184
"""
184-
# Define benchmark sizes
185+
# Define benchmark sizes focused on smaller matrices (as in the paper)
185186
benchmark_sizes = [
186-
(10, 10, 10),
187-
(20, 20, 20),
188-
(30, 30, 30),
189-
(40, 40, 40),
187+
(2, 2, 2),
188+
(3, 3, 3),
189+
(4, 4, 4),
190+
(5, 5, 5),
191+
(3, 4, 5),
192+
(4, 3, 5),
190193
]
191194

192-
# Define baseline times (naive implementation)
193-
# These would be measured in advance for the baseline implementation
195+
# Define baseline times for the naive triple-loop implementation
196+
# These are the reference times that our initial implementation should achieve
194197
baseline_times = {
195-
"10x10x10": 0.0015, # seconds
196-
"20x20x20": 0.0120, # seconds
197-
"30x30x30": 0.0400, # seconds
198-
"40x40x40": 0.0950, # seconds
198+
"2x2x2": 0.0001,
199+
"3x3x3": 0.0003,
200+
"4x4x4": 0.0007,
201+
"5x5x5": 0.0015,
202+
"3x4x5": 0.0007,
203+
"4x3x5": 0.0007,
204+
}
205+
206+
# Define target speedups (what we're aiming for)
207+
# Based on Strassen's algorithm and other optimized approaches
208+
target_speedups = {
209+
"2x2x2": 1.5, # 50% faster than naive
210+
"3x3x3": 1.7, # 70% faster than naive
211+
"4x4x4": 2.0, # 2x faster than naive
212+
"5x5x5": 2.2, # 2.2x faster than naive
213+
"3x4x5": 1.7, # 70% faster than naive
214+
"4x3x5": 1.7, # 70% faster than naive
199215
}
200216

201217
# Run benchmark
202218
results = {}
203-
runs = 3
219+
runs = 5 # More runs for better accuracy
204220

205221
for m, n, p in benchmark_sizes:
206222
size_key = f"{m}x{n}x{p}"
@@ -221,38 +237,56 @@ def evaluate_performance(matrix_multiply) -> float:
221237
end_time = time.time()
222238
times.append(end_time - start_time)
223239

224-
# Record average time
225-
avg_time = sum(times) / runs
240+
# Record average time (remove fastest and slowest)
241+
times.sort()
242+
if len(times) > 2:
243+
times = times[1:-1] # Remove extremes
244+
avg_time = sum(times) / len(times)
226245
results[size_key] = avg_time
227246
except Exception as e:
228247
logger.warning(f"Error in performance test for sizes {(m, n, p)}: {str(e)}")
229248
results[size_key] = baseline_times[size_key] * 2 # Penalize errors
230249

231-
# Calculate speedups
250+
# Calculate speedups relative to baseline
232251
speedups = {}
233252
for size, time_taken in results.items():
234253
if time_taken > 0:
235254
speedups[size] = baseline_times[size] / time_taken
236255
else:
237256
speedups[size] = 0
238257

239-
# Calculate overall score (geometric mean of speedups)
240-
if not speedups:
241-
return 0.0
258+
# Calculate performance relative to the target speedups
259+
target_percentages = {}
260+
for size, speedup in speedups.items():
261+
target = target_speedups[size]
262+
# If speedup is below 1.0, it's worse than baseline (score 0.0-0.2)
263+
# If speedup equals baseline, score is 0.2
264+
# If speedup is between baseline and target, score is 0.2-0.8
265+
# If speedup reaches target, score is 0.8
266+
# If speedup exceeds target, score is 0.8-1.0
267+
if speedup < 1.0:
268+
target_percentages[size] = 0.2 * speedup
269+
elif speedup < target:
270+
# Linear interpolation between 0.2 and 0.8
271+
progress = (speedup - 1.0) / (target - 1.0)
272+
target_percentages[size] = 0.2 + 0.6 * progress
273+
else:
274+
# Speedup reached or exceeded target
275+
bonus = min((speedup - target) / target, 0.5) # Cap bonus at 0.5
276+
target_percentages[size] = 0.8 + 0.2 * bonus
242277

243-
# Remove any zero speedups
244-
valid_speedups = [s for s in speedups.values() if s > 0]
245-
if not valid_speedups:
278+
# Calculate overall score (average of target percentages)
279+
if not target_percentages:
246280
return 0.0
247281

248-
# Calculate geometric mean
249-
import math
250-
log_sum = sum(math.log(s) for s in valid_speedups)
251-
geom_mean = math.exp(log_sum / len(valid_speedups))
282+
# Calculate average score
283+
avg_score = sum(target_percentages.values()) / len(target_percentages)
252284

253-
# Normalize to 0.0-1.0 range (assuming baseline = 1.0)
254-
# Values above 1.0 indicate improvement, below 1.0 indicate regression
255-
# Cap at 5.0x speedup for scoring purposes
256-
normalized_score = min(geom_mean / 5.0, 1.0)
285+
# Log detailed results for debugging
286+
logger.info(f"Performance results:")
287+
for size in benchmark_sizes:
288+
size_key = f"{size[0]}x{size[1]}x{size[2]}"
289+
if size_key in results and size_key in speedups and size_key in target_percentages:
290+
logger.info(f" {size_key}: time={results[size_key]:.6f}s, speedup={speedups[size_key]:.2f}x, score={target_percentages[size_key]:.2f}")
257291

258-
return normalized_score
292+
return avg_score

examples/matrix_multiplication/optimize.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,38 @@ async def main():
5454
config.diff_based_evolution = True
5555
config.allow_full_rewrites = False
5656

57+
# Create specialized template for matrix multiplication
58+
from openevolve.prompt.templates import TemplateManager
59+
60+
# Modify prompt templates to use specialized ones for matrix multiplication
61+
from openevolve.prompt.sampler import PromptSampler
62+
original_build_prompt = PromptSampler.build_prompt
63+
64+
def custom_build_prompt(self, *args, **kwargs):
65+
# Get template key from kwargs or use default
66+
template_key = kwargs.pop('template_key', 'diff_user') if 'template_key' in kwargs else 'diff_user'
67+
68+
# Use specialized template for matrix multiplication
69+
if template_key == 'diff_user':
70+
template_key = 'matmul_diff_user'
71+
72+
# Use specialized system message
73+
if args and len(args) >= 1:
74+
result = original_build_prompt(self, *args, **kwargs)
75+
if 'system' in result:
76+
template_manager = TemplateManager()
77+
result['system'] = template_manager.get_template('matmul_system')
78+
return result
79+
else:
80+
kwargs['template_key'] = template_key
81+
return original_build_prompt(self, *args, **kwargs)
82+
83+
# Apply the patch
84+
PromptSampler.build_prompt = custom_build_prompt
85+
86+
# Increase temperature for more creative solutions
87+
config.llm.temperature = 0.9
88+
5789
# Initialize OpenEvolve with the custom config
5890
openevolve = OpenEvolve(
5991
initial_program_path=str(initial_program_path),
@@ -65,6 +97,7 @@ async def main():
6597

6698
# Run evolution
6799
print(f"Starting evolution for {args.iterations} iterations...")
100+
print(f"Focus on optimizing matrix multiplication for small matrices (2x2 to 5x5)")
68101
best_program = await openevolve.run(iterations=args.iterations)
69102

70103
print(f"\nEvolution complete!")

openevolve/prompt/sampler.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,28 @@ def __init__(self, config: PromptConfig):
2121
# Initialize the random number generator
2222
random.seed()
2323

24+
# Store custom template mappings
25+
self.system_template_override = None
26+
self.user_template_override = None
27+
2428
logger.info("Initialized prompt sampler")
2529

30+
def set_templates(
31+
self,
32+
system_template: Optional[str] = None,
33+
user_template: Optional[str] = None
34+
) -> None:
35+
"""
36+
Set custom templates to use for this sampler
37+
38+
Args:
39+
system_template: Template name for system message
40+
user_template: Template name for user message
41+
"""
42+
self.system_template_override = system_template
43+
self.user_template_override = user_template
44+
logger.info(f"Set custom templates: system={system_template}, user={user_template}")
45+
2646
def build_prompt(
2747
self,
2848
current_program: str,
@@ -33,6 +53,7 @@ def build_prompt(
3353
language: str = "python",
3454
evolution_round: int = 0,
3555
allow_full_rewrite: bool = False,
56+
template_key: Optional[str] = None,
3657
) -> Dict[str, str]:
3758
"""
3859
Build a prompt for the LLM
@@ -46,14 +67,33 @@ def build_prompt(
4667
language: Programming language
4768
evolution_round: Current evolution round
4869
allow_full_rewrite: Whether to allow a full rewrite
70+
template_key: Optional override for template key
4971
5072
Returns:
5173
Dictionary with 'system' and 'user' keys
5274
"""
53-
# Select template based on whether we want a full rewrite
54-
template_key = "full_rewrite_user" if allow_full_rewrite else "diff_user"
55-
user_template = self.template_manager.get_template(template_key)
56-
system_template = self.config.system_message
75+
# Select template based on whether we want a full rewrite (with overrides)
76+
if template_key:
77+
# Use explicitly provided template key
78+
user_template_key = template_key
79+
elif self.user_template_override:
80+
# Use the override set with set_templates
81+
user_template_key = self.user_template_override
82+
else:
83+
# Default behavior
84+
user_template_key = "full_rewrite_user" if allow_full_rewrite else "diff_user"
85+
86+
# Get the template
87+
user_template = self.template_manager.get_template(user_template_key)
88+
89+
# Use system template override if set
90+
if self.system_template_override:
91+
system_message = self.template_manager.get_template(self.system_template_override)
92+
else:
93+
system_message = self.config.system_message
94+
# If system_message is a template name rather than content, get the template
95+
if system_message in self.template_manager.templates:
96+
system_message = self.template_manager.get_template(system_message)
5797

5898
# Format metrics
5999
metrics_str = self._format_metrics(program_metrics)
@@ -82,7 +122,7 @@ def build_prompt(
82122
)
83123

84124
return {
85-
"system": system_template,
125+
"system": system_message,
86126
"user": user_message,
87127
}
88128

openevolve/prompt/templates.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
Focus on making targeted changes that will increase the program's performance metrics.
1212
"""
1313

14+
# Matrix multiplication system template
15+
MATMUL_SYSTEM_TEMPLATE = """You are an expert algorithm engineer specialized in numerical computing and matrix operations.
16+
Your task is to optimize matrix multiplication algorithms for better performance while maintaining correctness.
17+
Apply techniques like loop reordering, blocking, recursion, and mathematical insights to reduce the number of operations.
18+
Focus on making improvements for smaller matrix sizes (2x2 to 5x5) where algorithmic innovations like Strassen's algorithm can make a difference.
19+
"""
20+
1421
# User message template for diff-based evolution
1522
DIFF_USER_TEMPLATE = """# Current Program Information
1623
- Current performance metrics: {metrics}
@@ -26,20 +33,86 @@
2633
2734
# Task
2835
Suggest improvements to the program that will lead to better performance on the specified metrics.
29-
Use the SEARCH/REPLACE diff format to indicate changes:
36+
37+
You MUST use the exact SEARCH/REPLACE diff format shown below to indicate changes:
3038
3139
<<<<<<< SEARCH
32-
# Code to find and replace
40+
# Original code to find and replace (must match exactly)
3341
=======
3442
# New replacement code
3543
>>>>>>> REPLACE
3644
37-
You can suggest multiple changes. Make sure each SEARCH section exactly matches code in the current program.
38-
Be thoughtful about your changes and explain your reasoning.
45+
Example of valid diff format:
46+
<<<<<<< SEARCH
47+
for i in range(m):
48+
for j in range(p):
49+
for k in range(n):
50+
C[i, j] += A[i, k] * B[k, j]
51+
=======
52+
# Reorder loops for better memory access pattern
53+
for i in range(m):
54+
for k in range(n):
55+
for j in range(p):
56+
C[i, j] += A[i, k] * B[k, j]
57+
>>>>>>> REPLACE
58+
59+
You can suggest multiple changes. Each SEARCH section must exactly match code in the current program.
60+
Be thoughtful about your changes and explain your reasoning thoroughly.
3961
4062
IMPORTANT: Do not rewrite the entire program - focus on targeted improvements.
4163
"""
4264

65+
# Matrix multiplication specific template
66+
MATMUL_DIFF_USER_TEMPLATE = """# Matrix Multiplication Optimization Task
67+
- Current performance metrics: {metrics}
68+
- Areas identified for improvement: {improvement_areas}
69+
70+
# Program Evolution History
71+
{evolution_history}
72+
73+
# Current Program
74+
```{language}
75+
{current_program}
76+
```
77+
78+
# Task
79+
Optimize the matrix multiplication algorithm for better performance while maintaining correctness.
80+
Focus on smaller matrix sizes (2x2 to 5x5) where algorithmic innovations can make a significant difference.
81+
82+
Consider these optimization strategies:
83+
1. Loop reordering for better cache locality
84+
2. Loop unrolling to reduce loop overhead
85+
3. Blocking/tiling for better memory access patterns
86+
4. Algorithmic improvements like Strassen's algorithm for recursive decomposition
87+
5. Special case handling for specific matrix sizes
88+
6. Vectorization hints and SIMD-friendly operations
89+
90+
You MUST use the exact SEARCH/REPLACE diff format shown below to indicate changes:
91+
92+
<<<<<<< SEARCH
93+
# Original code to find and replace (must match exactly)
94+
=======
95+
# New replacement code
96+
>>>>>>> REPLACE
97+
98+
Example of valid diff format:
99+
<<<<<<< SEARCH
100+
for i in range(m):
101+
for j in range(p):
102+
for k in range(n):
103+
C[i, j] += A[i, k] * B[k, j]
104+
=======
105+
# Reorder loops for better memory access pattern
106+
for i in range(m):
107+
for k in range(n):
108+
for j in range(p):
109+
C[i, j] += A[i, k] * B[k, j]
110+
>>>>>>> REPLACE
111+
112+
You can suggest multiple changes. Each SEARCH section must exactly match code in the current program.
113+
Explain the reasoning behind your optimizations.
114+
"""
115+
43116
# User message template for full rewrite
44117
FULL_REWRITE_USER_TEMPLATE = """# Current Program Information
45118
- Current performance metrics: {metrics}
@@ -93,7 +166,9 @@
93166
# Default templates dictionary
94167
DEFAULT_TEMPLATES = {
95168
"system_message": BASE_SYSTEM_TEMPLATE,
169+
"matmul_system": MATMUL_SYSTEM_TEMPLATE,
96170
"diff_user": DIFF_USER_TEMPLATE,
171+
"matmul_diff_user": MATMUL_DIFF_USER_TEMPLATE,
97172
"full_rewrite_user": FULL_REWRITE_USER_TEMPLATE,
98173
"evolution_history": EVOLUTION_HISTORY_TEMPLATE,
99174
"previous_attempt": PREVIOUS_ATTEMPT_TEMPLATE,

0 commit comments

Comments
 (0)