Skip to content

Commit c6371fc

Browse files
committed
fix
1 parent f4d158b commit c6371fc

File tree

2 files changed

+102
-32
lines changed

2 files changed

+102
-32
lines changed

examples/matrix_multiplication/evaluate.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
import sys
1010
import time
11+
import traceback
1112
from typing import Dict, List, Tuple, Any
1213

1314
import numpy as np
@@ -49,6 +50,17 @@ def evaluate(program_path: str) -> Dict[str, float]:
4950
Returns:
5051
Dictionary of metric name to score
5152
"""
53+
# First perform basic validation
54+
stage1_result = evaluate_stage1(program_path)
55+
56+
# If validation fails, return early
57+
if stage1_result["correctness"] < 0.8:
58+
return {
59+
"correctness": stage1_result["correctness"],
60+
"rank_quality": 0.0,
61+
"time_efficiency": 0.0
62+
}
63+
5264
# Import the program
5365
try:
5466
spec = importlib.util.spec_from_file_location("program_module", program_path)
@@ -97,14 +109,42 @@ def evaluate(program_path: str) -> Dict[str, float]:
97109

98110
def evaluate_stage1(program_path: str) -> Dict[str, float]:
99111
"""
100-
First stage of evaluation: test correctness
112+
First stage of evaluation: basic validation and test correctness
101113
102114
Args:
103115
program_path: Path to the program file
104116
105117
Returns:
106118
Dictionary of metric name to score
107119
"""
120+
# First, perform static code analysis and basic validation
121+
try:
122+
with open(program_path, 'r') as f:
123+
code_content = f.read()
124+
125+
# Basic syntax check
126+
try:
127+
compile(code_content, program_path, 'exec')
128+
except SyntaxError as e:
129+
logger.error(f"Syntax error in program: {str(e)}")
130+
return {"correctness": 0.0}
131+
132+
# Check for common issues
133+
if "TensorDecomposition" not in code_content:
134+
logger.error("Program does not contain 'TensorDecomposition' class")
135+
return {"correctness": 0.0}
136+
137+
# Look for variable reference issues (e.g., 'u_factors' being used before definition)
138+
if "u_factors" in code_content and "_initialize_decomposition" in code_content:
139+
# Very basic check - not exhaustive but catches simple issues
140+
if code_content.find("u_factors") < code_content.find("_initialize_decomposition"):
141+
if "def _initialize_decomposition" in code_content:
142+
logger.error("Possible reference to 'u_factors' before initialization")
143+
return {"correctness": 0.0}
144+
except Exception as e:
145+
logger.error(f"Error during static code validation: {str(e)}")
146+
return {"correctness": 0.0}
147+
108148
# Import the program
109149
try:
110150
spec = importlib.util.spec_from_file_location("program_module", program_path)
@@ -113,20 +153,50 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
113153

114154
module = importlib.util.module_from_spec(spec)
115155
sys.modules["program_module"] = module
116-
spec.loader.exec_module(module)
117156

157+
# Use a safety wrapper to catch any import-time errors
158+
try:
159+
spec.loader.exec_module(module)
160+
except Exception as e:
161+
logger.error(f"Error during module execution: {str(e)}")
162+
traceback.print_exc()
163+
return {"correctness": 0.0}
164+
165+
# Check for the required class
118166
if not hasattr(module, "TensorDecomposition"):
119-
raise AttributeError(f"Program does not contain a 'TensorDecomposition' class")
167+
logger.error("Program does not contain a 'TensorDecomposition' class")
168+
return {"correctness": 0.0}
120169

170+
# Check basic class structure
121171
TensorDecomposition = module.TensorDecomposition
172+
required_methods = ["__init__", "optimize", "_initialize_decomposition"]
173+
for method in required_methods:
174+
if not hasattr(TensorDecomposition, method):
175+
logger.error(f"TensorDecomposition class missing required method: {method}")
176+
return {"correctness": 0.0}
177+
178+
# Try to instantiate the class with minimal parameters
179+
try:
180+
test_instance = TensorDecomposition(target_shape=(2, 2, 2), rank=7)
181+
except Exception as e:
182+
logger.error(f"Failed to instantiate TensorDecomposition: {str(e)}")
183+
traceback.print_exc()
184+
return {"correctness": 0.0}
185+
122186
except Exception as e:
123187
logger.error(f"Error importing program: {str(e)}")
188+
traceback.print_exc()
124189
return {"correctness": 0.0}
125190

126-
# Test correctness
127-
correctness_score = evaluate_correctness(TensorDecomposition)
128-
129-
return {"correctness": correctness_score}
191+
# If we get here, the basic validation passed
192+
# Now perform a simple correctness test
193+
try:
194+
correctness_score = evaluate_correctness(TensorDecomposition)
195+
return {"correctness": correctness_score}
196+
except Exception as e:
197+
logger.error(f"Error in correctness evaluation: {str(e)}")
198+
traceback.print_exc()
199+
return {"correctness": 0.0}
130200

131201

132202
def evaluate_correctness(TensorDecomposition) -> float:

examples/matrix_multiplication/optimize.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,55 +77,55 @@ async def main():
7777

7878
# System message focusing on tensor decomposition and optimization
7979
system_template = """You are an expert in computational mathematics, numerical optimization, and algorithm design.
80-
Your task is to optimize a tensor decomposition algorithm for discovering efficient matrix multiplication.
80+
Your task is to carefully optimize a tensor decomposition algorithm for discovering efficient matrix multiplication.
8181
8282
When matrix multiplication is viewed as a tensor problem, the goal is to find a minimum-rank decomposition
8383
of the corresponding 3D tensor. Each term in the decomposition corresponds to a scalar multiplication in
8484
the algorithm, so minimizing the rank directly leads to faster matrix multiplication.
8585
86-
Your focus should be on targeted, specific improvements to the tensor decomposition optimization process.
87-
Make small, focused changes rather than rewriting large sections. Key areas to consider include:
86+
Your focus should be on making ONE targeted, specific improvement. Code quality is critical - buggy
87+
code will fail evaluation and waste computing resources. Pay special attention to:
8888
89-
1. Loss function improvements (regularization terms that encourage specific properties)
90-
2. Adding noise injection or annealing schedules to avoid local minima
91-
3. Improving initialization strategies for better convergence
92-
4. Numerical stability enhancements for complex-valued operations
93-
5. Techniques to encourage integer or half-integer coefficients
89+
1. Variable scope and references - don't use variables before they're defined
90+
2. Make only small, focused changes rather than large rewrites
91+
3. Test your change carefully, ensuring it will work when added to the existing code
92+
4. Respect the class structure and function interfaces
9493
95-
The best algorithms from the literature for various matrix sizes include:
96-
- 2x2 matrices: Strassen's algorithm (7 multiplications)
97-
- 3x3 matrices: Laderman's algorithm (23 multiplications)
98-
- 4x4 matrices: Recursive Strassen (49 multiplications, improved to 48 with complex values)
94+
The best matrix multiplication algorithms in the literature include Strassen's algorithm (7 multiplications
95+
for 2x2) and Laderman's algorithm (23 multiplications for 3x3).
9996
100-
Focus on making 1-3 specific, high-impact changes rather than comprehensive rewrites.
97+
Make only ONE specific, high-impact change rather than multiple modifications.
10198
"""
10299

103-
# User message template for tensor decomposition optimization (shortened for compatibility)
100+
# User message template for tensor decomposition optimization (focused and clear)
104101
user_template = """
105-
Improve the tensor decomposition algorithm below with 1-3 specific changes.
102+
Focus on fixing ONE specific issue or making ONE targeted improvement in this tensor decomposition algorithm.
106103
107-
CURRENT CODE:
104+
CODE:
108105
{current_program}
109106
110107
METRICS:
111108
{metrics}
112109
113-
FOCUS AREAS:
114-
{improvement_areas}
110+
FOCUS AREA: Make ONE change to improve the algorithm, focusing on either:
111+
1. The loss function to guide optimization better
112+
2. Initialization strategy for faster convergence
113+
3. Adding regularization for integer/half-integer coefficients
115114
116-
Make targeted changes using this format:
115+
Use SEARCH/REPLACE (keep the search section SHORT):
117116
118117
<<<<<<< SEARCH
119-
// exact code to match (keep short)
118+
// exact code to match (small section only)
120119
=======
121120
// improved code
122121
>>>>>>> REPLACE
123122
124-
RULES:
125-
1. Each SEARCH block must exactly match existing code
126-
2. Focus on 1-3 specific changes only
127-
3. Explain each change briefly
128-
4. Avoid large rewrites
123+
IMPORTANT RULES:
124+
1. Make only ONE change and explain it briefly
125+
2. Ensure the SEARCH block EXACTLY matches existing code
126+
3. NEVER use variables before they're defined
127+
4. Don't reference u_factors, v_factors, or w_factors outside _initialize_decomposition, _loss_fn, etc.
128+
5. ALWAYS test your change mentally to ensure it would work
129129
"""
130130

131131
# Add the templates to OpenEvolve's template manager

0 commit comments

Comments
 (0)