Commit 08c5b80

Commit message: f

1 parent 7355f2f commit 08c5b80
4 files changed, +216 -225 lines changed

examples/mlx_metal_kernel_opt/README.md

Lines changed: 1 addition & 27 deletions
```diff
@@ -1,4 +1,4 @@
-# 🎯 Qwen3-0.6B Custom Metal Kernel Optimization with OpenEvolve
+# 🎯Custom Metal Kernel Optimization with OpenEvolve
 
 **Evolving custom GPU kernels for Grouped Query Attention using MLX Metal kernels for Qwen3-0.6B on Apple Silicon**
 
@@ -416,29 +416,3 @@ python run_benchmarks.py --mode compare
 ---
 
 **🎯 This example demonstrates OpenEvolve's capability to discover genuine algorithmic improvements through evolutionary optimization, achieving measurable performance gains on real hardware with production-ready implementations.**
-
-## 🔧 **Recent Improvements**
-
-### **✅ Correct Terminology**
-- **Before**: Incorrect references to "chunked GQA processing"
-- **After**: Accurate descriptions of custom Metal kernel optimization
-- **Benefits**: Technical accuracy and clear understanding of actual discoveries
-
-### **✅ Comprehensive Testing**
-- **Before**: Basic performance measurement
-- **After**: 17-scenario comprehensive benchmark suite with statistical validation
-- **Benefits**: Robust performance analysis and reproducible results
-
-### **✅ Production Integration**
-- **Before**: Standalone optimization experiments
-- **After**: Full MLX-LM integration with seamless switching
-- **Benefits**: Real-world usability and easy adoption
-
-### **✅ Detailed Documentation**
-- **Before**: High-level optimization descriptions
-- **After**: Complete technical details with actual kernel code snippets
-- **Benefits**: Understanding, reproducibility, and further research
-
----
-
-**🚀 Ready for custom Metal kernel evolution with comprehensive benchmarking and detailed analysis!**
```

openevolve/evaluator.py

Lines changed: 68 additions & 116 deletions
```diff
@@ -134,6 +134,20 @@ async def evaluate_program(
                 # Process the result based on type
                 eval_result = self._process_evaluation_result(result)
 
+                # Check if this was a timeout and capture artifacts if enabled
+                if artifacts_enabled and program_id and eval_result.metrics.get("timeout") is True:
+                    if program_id not in self._pending_artifacts:
+                        self._pending_artifacts[program_id] = {}
+
+                    self._pending_artifacts[program_id].update(
+                        {
+                            "timeout": True,
+                            "timeout_duration": self.config.timeout,
+                            "failure_stage": "evaluation",
+                            "error_type": "timeout",
+                        }
+                    )
+
                 # Add LLM feedback if configured
                 llm_eval_result = None
                 if self.config.use_llm_feedback and self.llm_ensemble:
@@ -153,7 +167,8 @@ async def evaluate_program(
                         )
                         and program_id
                     ):
-                        self._pending_artifacts[program_id] = {}
+                        if program_id not in self._pending_artifacts:
+                            self._pending_artifacts[program_id] = {}
 
                         # Merge eval_result artifacts with llm artifacts if they exist
                         if eval_result.has_artifacts():
@@ -179,6 +194,21 @@ async def evaluate_program(
                 # Return just metrics for backward compatibility
                 return eval_result.metrics
 
+            except asyncio.TimeoutError:
+                # Handle timeout specially - don't retry, just return timeout result
+                logger.warning(f"Evaluation timed out after {self.config.timeout}s")
+
+                # Capture timeout artifacts if enabled
+                if artifacts_enabled and program_id:
+                    self._pending_artifacts[program_id] = {
+                        "timeout": True,
+                        "timeout_duration": self.config.timeout,
+                        "failure_stage": "evaluation",
+                        "error_type": "timeout",
+                    }
+
+                return {"error": 0.0, "timeout": True}
+
             except Exception as e:
                 last_exception = e
                 logger.warning(
@@ -192,6 +222,7 @@ async def evaluate_program(
                         "stderr": str(e),
                         "traceback": traceback.format_exc(),
                         "failure_stage": "evaluation",
+                        "attempt": attempt + 1,
                     }
 
                 # If this is not the last attempt, wait a bit before retrying
```
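Taken together, the `evaluate_program` hunks move timeout handling into the retry loop itself: a timed-out candidate records timeout artifacts and returns `{"error": 0.0, "timeout": True}` immediately, while other exceptions keep the existing retry-with-artifacts behaviour. Below is a minimal, self-contained sketch of that control flow; the names `run_once`, `MAX_RETRIES`, `TIMEOUT_S`, and the module-level `pending_artifacts` dict are simplified stand-ins for illustration, not the evaluator's actual API.

```python
import asyncio
import logging
import traceback

logger = logging.getLogger(__name__)

TIMEOUT_S = 5.0      # stands in for self.config.timeout
MAX_RETRIES = 3      # stands in for the configured retry count
pending_artifacts: dict[str, dict] = {}  # stands in for self._pending_artifacts


async def run_once(program_path: str) -> dict:
    """Placeholder for the real evaluation call (e.g. a user-supplied evaluator)."""
    await asyncio.sleep(10)  # pretend the evaluation hangs
    return {"score": 1.0}


async def evaluate_with_retries(program_id: str, program_path: str) -> dict:
    for attempt in range(MAX_RETRIES):
        try:
            # One bounded attempt; timeouts and errors are handled below.
            return await asyncio.wait_for(run_once(program_path), timeout=TIMEOUT_S)
        except asyncio.TimeoutError:
            # Timeouts are not retried: record artifacts and bail out immediately.
            logger.warning(f"Evaluation timed out after {TIMEOUT_S}s")
            pending_artifacts.setdefault(program_id, {}).update(
                {
                    "timeout": True,
                    "timeout_duration": TIMEOUT_S,
                    "failure_stage": "evaluation",
                    "error_type": "timeout",
                }
            )
            return {"error": 0.0, "timeout": True}
        except Exception as e:
            # Ordinary failures keep the retry behaviour and record which attempt failed.
            logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}")
            pending_artifacts.setdefault(program_id, {}).update(
                {
                    "stderr": str(e),
                    "traceback": traceback.format_exc(),
                    "failure_stage": "evaluation",
                    "attempt": attempt + 1,
                }
            )
    return {"error": 0.0}


if __name__ == "__main__":
    print(asyncio.run(evaluate_with_retries("prog-1", "program.py")))
    print(pending_artifacts)
```

A program that has already hit the limit is unlikely to finish within the same limit on the next attempt, so skipping retries for timeouts avoids spending up to retries × timeout seconds on a single hanging candidate.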
```diff
@@ -251,30 +282,27 @@ async def _direct_evaluate(self, program_path: str) -> Dict[str, float]:
 
         Returns:
             Dictionary of metric name to score
-        """
-        try:
-            # Create a coroutine that runs the evaluation function in an executor
-            async def run_evaluation():
-                loop = asyncio.get_event_loop()
-                return await loop.run_in_executor(None, self.evaluate_function, program_path)
 
-            # Run the evaluation with timeout
-            result = await asyncio.wait_for(run_evaluation(), timeout=self.config.timeout)
+        Raises:
+            asyncio.TimeoutError: If evaluation exceeds timeout
+            Exception: If evaluation function raises an exception
+        """
 
-            # Validate result
-            if not isinstance(result, dict):
-                logger.warning(f"Evaluation returned non-dictionary result: {result}")
-                return {"error": 0.0}
+        # Create a coroutine that runs the evaluation function in an executor
+        async def run_evaluation():
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(None, self.evaluate_function, program_path)
 
-            return result
+        # Run the evaluation with timeout - let exceptions bubble up for retry handling
+        result = await asyncio.wait_for(run_evaluation(), timeout=self.config.timeout)
 
-        except asyncio.TimeoutError:
-            logger.warning(f"Evaluation timed out after {self.config.timeout}s")
-            return {"error": 0.0, "timeout": True}
-        except Exception as e:
-            logger.error(f"Error in direct evaluation: {str(e)}")
+        # Validate result
+        if not isinstance(result, dict):
+            logger.warning(f"Evaluation returned non-dictionary result: {result}")
             return {"error": 0.0}
 
+        return result
+
     async def _cascade_evaluate(
         self, program_path: str
     ) -> Union[Dict[str, float], EvaluationResult]:
```
```diff
@@ -286,6 +314,10 @@ async def _cascade_evaluate(
 
         Returns:
             Dictionary of metrics or EvaluationResult with metrics and artifacts
+
+        Raises:
+            asyncio.TimeoutError: If any stage exceeds timeout
+            Exception: If any evaluation stage raises an exception
         """
         # Import the evaluation module to get cascade functions if they exist
         try:
@@ -307,34 +339,12 @@ async def _cascade_evaluate(
                 return await self._direct_evaluate(program_path)
 
             # Run first stage with timeout
-            try:
+            async def run_stage1():
+                loop = asyncio.get_event_loop()
+                return await loop.run_in_executor(None, module.evaluate_stage1, program_path)
 
-                async def run_stage1():
-                    loop = asyncio.get_event_loop()
-                    return await loop.run_in_executor(None, module.evaluate_stage1, program_path)
-
-                stage1_result = await asyncio.wait_for(run_stage1(), timeout=self.config.timeout)
-                stage1_eval_result = self._process_evaluation_result(stage1_result)
-            except asyncio.TimeoutError:
-                logger.warning(f"Stage 1 evaluation timed out after {self.config.timeout}s")
-                return EvaluationResult(
-                    metrics={"stage1_passed": 0.0, "error": 0.0, "timeout": True},
-                    artifacts={
-                        "failure_stage": "stage1",
-                        "timeout": True,
-                    },
-                )
-            except Exception as e:
-                logger.error(f"Error in stage 1 evaluation: {str(e)}")
-                # Capture stage 1 failure as artifacts
-                return EvaluationResult(
-                    metrics={"stage1_passed": 0.0, "error": 0.0},
-                    artifacts={
-                        "stderr": str(e),
-                        "traceback": traceback.format_exc(),
-                        "failure_stage": "stage1",
-                    },
-                )
+            stage1_result = await asyncio.wait_for(run_stage1(), timeout=self.config.timeout)
+            stage1_eval_result = self._process_evaluation_result(stage1_result)
 
             # Check threshold
             if not self._passes_threshold(
@@ -347,38 +357,12 @@ async def run_stage1():
                 return stage1_eval_result
 
             # Run second stage with timeout
-            try:
-
-                async def run_stage2():
-                    loop = asyncio.get_event_loop()
-                    return await loop.run_in_executor(None, module.evaluate_stage2, program_path)
+            async def run_stage2():
+                loop = asyncio.get_event_loop()
+                return await loop.run_in_executor(None, module.evaluate_stage2, program_path)
 
-                stage2_result = await asyncio.wait_for(run_stage2(), timeout=self.config.timeout)
-                stage2_eval_result = self._process_evaluation_result(stage2_result)
-            except asyncio.TimeoutError:
-                logger.warning(f"Stage 2 evaluation timed out after {self.config.timeout}s")
-                # Capture stage 2 failure, but keep stage 1 results
-                stage1_eval_result.artifacts.update(
-                    {
-                        "stage2_timeout": True,
-                        "failure_stage": "stage2",
-                    }
-                )
-                stage1_eval_result.metrics["stage2_passed"] = 0.0
-                stage1_eval_result.metrics["timeout"] = True
-                return stage1_eval_result
-            except Exception as e:
-                logger.error(f"Error in stage 2 evaluation: {str(e)}")
-                # Capture stage 2 failure, but keep stage 1 results
-                stage1_eval_result.artifacts.update(
-                    {
-                        "stage2_stderr": str(e),
-                        "stage2_traceback": traceback.format_exc(),
-                        "failure_stage": "stage2",
-                    }
-                )
-                stage1_eval_result.metrics["stage2_passed"] = 0.0
-                return stage1_eval_result
+            stage2_result = await asyncio.wait_for(run_stage2(), timeout=self.config.timeout)
+            stage2_eval_result = self._process_evaluation_result(stage2_result)
 
             # Merge results from stage 1 and 2
             merged_metrics = {}
@@ -409,38 +393,12 @@ async def run_stage2():
                 return merged_result
 
             # Run third stage with timeout
-            try:
+            async def run_stage3():
+                loop = asyncio.get_event_loop()
+                return await loop.run_in_executor(None, module.evaluate_stage3, program_path)
 
-                async def run_stage3():
-                    loop = asyncio.get_event_loop()
-                    return await loop.run_in_executor(None, module.evaluate_stage3, program_path)
-
-                stage3_result = await asyncio.wait_for(run_stage3(), timeout=self.config.timeout)
-                stage3_eval_result = self._process_evaluation_result(stage3_result)
-            except asyncio.TimeoutError:
-                logger.warning(f"Stage 3 evaluation timed out after {self.config.timeout}s")
-                # Capture stage 3 failure, but keep previous results
-                merged_result.artifacts.update(
-                    {
-                        "stage3_timeout": True,
-                        "failure_stage": "stage3",
-                    }
-                )
-                merged_result.metrics["stage3_passed"] = 0.0
-                merged_result.metrics["timeout"] = True
-                return merged_result
-            except Exception as e:
-                logger.error(f"Error in stage 3 evaluation: {str(e)}")
-                # Capture stage 3 failure, but keep previous results
-                merged_result.artifacts.update(
-                    {
-                        "stage3_stderr": str(e),
-                        "stage3_traceback": traceback.format_exc(),
-                        "failure_stage": "stage3",
-                    }
-                )
-                merged_result.metrics["stage3_passed"] = 0.0
-                return merged_result
+            stage3_result = await asyncio.wait_for(run_stage3(), timeout=self.config.timeout)
+            stage3_eval_result = self._process_evaluation_result(stage3_result)
 
             # Merge stage 3 results
             for name, value in stage3_eval_result.metrics.items():
@@ -453,14 +411,8 @@ async def run_stage3():
 
         except Exception as e:
            logger.error(f"Error in cascade evaluation: {str(e)}")
-            return EvaluationResult(
-                metrics={"error": 0.0},
-                artifacts={
-                    "stderr": str(e),
-                    "traceback": traceback.format_exc(),
-                    "failure_stage": "cascade_setup",
-                },
-            )
+            # Re-raise the exception to allow retry handling at higher level
+            raise
 
     async def _llm_evaluate(self, program_code: str, program_id: str = "") -> Dict[str, float]:
         """
```
openevolve/utils/async_utils.py

Lines changed: 8 additions & 16 deletions
```diff
@@ -33,28 +33,24 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
 
 
 async def run_with_timeout(
-    coro: Callable,
-    timeout: float,
-    *args: Any,
-    timeout_error_value: Any = None,
-    **kwargs: Any
+    coro: Callable, timeout: float, *args: Any, timeout_error_value: Any = None, **kwargs: Any
 ) -> Any:
     """
     Run a coroutine with a timeout, returning a default value on timeout
-
+
     Args:
         coro: Coroutine function to run
         timeout: Timeout in seconds
         *args: Arguments to pass to the coroutine
         timeout_error_value: Value to return on timeout (default: {"error": 0.0, "timeout": True})
         **kwargs: Keyword arguments to pass to the coroutine
-
+
     Returns:
         Result of the coroutine or timeout_error_value on timeout
     """
     if timeout_error_value is None:
         timeout_error_value = {"error": 0.0, "timeout": True}
-
+
     try:
         return await asyncio.wait_for(coro(*args, **kwargs), timeout=timeout)
     except asyncio.TimeoutError:
@@ -63,28 +59,24 @@ async def run_with_timeout(
 
 
 async def run_sync_with_timeout(
-    func: Callable,
-    timeout: float,
-    *args: Any,
-    timeout_error_value: Any = None,
-    **kwargs: Any
+    func: Callable, timeout: float, *args: Any, timeout_error_value: Any = None, **kwargs: Any
 ) -> Any:
     """
     Run a synchronous function in an executor with a timeout
-
+
     Args:
         func: Synchronous function to run
         timeout: Timeout in seconds
         *args: Arguments to pass to the function
         timeout_error_value: Value to return on timeout (default: {"error": 0.0, "timeout": True})
         **kwargs: Keyword arguments to pass to the function
-
+
     Returns:
         Result of the function or timeout_error_value on timeout
     """
     if timeout_error_value is None:
         timeout_error_value = {"error": 0.0, "timeout": True}
-
+
     try:
         loop = asyncio.get_event_loop()
         task = loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
```
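Both `async_utils.py` hunks are formatting-only: the signatures are collapsed onto one line and trailing whitespace is trimmed from blank lines, with no behavioural change. For reference, here is a self-contained sketch of how a helper with the `run_sync_with_timeout` signature shown above is used. The body is a minimal reconstruction from the visible lines for illustration only, and `slow_evaluation` is a made-up stand-in for a blocking evaluator function.

```python
import asyncio
import functools
import time
from typing import Any, Callable


async def run_sync_with_timeout(
    func: Callable, timeout: float, *args: Any, timeout_error_value: Any = None, **kwargs: Any
) -> Any:
    """Run a blocking function in the default executor, bounded by a timeout."""
    if timeout_error_value is None:
        timeout_error_value = {"error": 0.0, "timeout": True}

    loop = asyncio.get_event_loop()
    task = loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
    try:
        return await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError:
        return timeout_error_value


def slow_evaluation(program_path: str, delay: float = 2.0) -> dict:
    """Stand-in for a user-supplied, blocking evaluate() function."""
    time.sleep(delay)
    return {"score": 1.0}


async def main() -> None:
    # Finishes within the limit -> returns the real result
    print(await run_sync_with_timeout(slow_evaluation, 5.0, "program.py", delay=1.0))
    # Exceeds the limit -> returns the default timeout value
    print(await run_sync_with_timeout(slow_evaluation, 0.5, "program.py", delay=2.0))


if __name__ == "__main__":
    asyncio.run(main())
```

The fast call prints the real result, while the one that exceeds its budget prints the default `{"error": 0.0, "timeout": True}` value.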
