Skip to content

Commit 23b0495

Browse files
committed
Update evaluator.py
1 parent 089518e commit 23b0495

File tree

1 file changed

+57
-92
lines changed

1 file changed

+57
-92
lines changed

examples/mlx_metal_kernel_opt/evaluator.py

Lines changed: 57 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,28 @@
11
"""
2-
Robust Qwen3 Custom GQA Attention Evaluator with Comprehensive Metal Kernel Error Handling
2+
Thread-Safe Robust Qwen3 Custom GQA Attention Evaluator
33
4-
This evaluator provides bulletproof protection against Metal kernel failures that terminate evolution:
4+
This evaluator provides bulletproof protection against Metal kernel failures without using signals:
55
6-
🛡️ PROTECTION FEATURES:
7-
1. Signal-based timeout handling for hanging Metal kernels
8-
2. Comprehensive C++ exception catching with try-catch blocks
9-
3. Process isolation for dangerous Metal kernel execution
10-
4. Retry mechanisms with exponential backoff
11-
5. Graceful fallback to standard attention on failures
12-
6. Detailed error classification and recovery strategies
6+
🛡️ THREAD-SAFE PROTECTION:
7+
1. No signal-based timeouts (works in worker threads)
8+
2. Comprehensive C++ exception catching
9+
3. Retry mechanisms with exponential backoff
10+
4. Graceful fallback to standard attention on failures
11+
5. Detailed error classification and recovery
1312
1413
🔧 EVOLUTION SAFETY:
1514
- Never terminates the evolution process due to kernel errors
15+
- Works perfectly in OpenEvolve's worker threads
1616
- Provides meaningful feedback on kernel failure types
17-
- Maintains evaluation progress even with problematic kernels
1817
- Statistical tracking of Metal kernel error patterns
19-
20-
Evolution Target:
21-
- Custom GQA implementation using MLX primitives
22-
- 40:8 query-to-KV head pattern optimization
23-
- Safe evolution despite Metal kernel instability
2418
"""
2519

2620
import os
2721
import sys
2822
import json
2923
import time
3024
import traceback
31-
import signal
32-
import subprocess
33-
import tempfile
25+
import threading
3426
from typing import Dict, List, Tuple, Any, Optional
3527
import numpy as np
3628

@@ -49,21 +41,20 @@ class MetalKernelError(Exception):
4941
pass
5042

5143

52-
class TimeoutError(Exception):
53-
"""Custom timeout exception for compatibility"""
44+
class ThreadSafeTimeoutError(Exception):
45+
"""Thread-safe timeout exception"""
5446
pass
5547

5648

57-
class RobustCustomGQAEvaluator:
58-
"""Bulletproof evaluator that never crashes from Metal kernel errors"""
49+
class ThreadSafeRobustEvaluator:
50+
"""Thread-safe bulletproof evaluator that never crashes from Metal kernel errors"""
5951

6052
def __init__(self):
6153
self.model_path = "mlx-community/Qwen3-0.6B-bf16"
6254

63-
# Error handling configuration
64-
self.metal_kernel_timeout = 45 # 45 second timeout for Metal operations
55+
# Error handling configuration (no signal-based timeouts)
56+
self.metal_kernel_timeout = 45 # Reference only, no actual timeout enforcement
6557
self.max_retry_attempts = 2
66-
self.use_process_isolation = False # Disable for now, causes import issues
6758

6859
# Error tracking
6960
self.metal_errors_caught = 0
@@ -77,27 +68,26 @@ def __init__(self):
7768
# Use comprehensive benchmark suite for consistency
7869
self.benchmark_suite = Qwen3BenchmarkSuite(self.model_path)
7970

80-
print("🛡️ Initialized Robust Custom GQA Evaluator")
71+
print("🛡️ Initialized Thread-Safe Robust Custom GQA Evaluator")
8172
print(f"📱 Model: {self.model_path}")
82-
print(f"⏱️ Metal kernel timeout: {self.metal_kernel_timeout}s")
8373
print(f"🔁 Max retry attempts: {self.max_retry_attempts}")
84-
print(f"🚫 Process isolation: {self.use_process_isolation}")
74+
print(f"🧵 Thread-safe: No signal dependencies")
8575

8676
def evaluate(self, program_text: str) -> Dict[str, Any]:
8777
"""
88-
Bulletproof evaluation that never crashes:
78+
Thread-safe bulletproof evaluation that never crashes:
8979
1. Safe extraction with syntax validation
9080
2. Protected baseline measurement
91-
3. Isolated correctness testing with timeouts
81+
3. Isolated correctness testing
9282
4. Robust benchmarking with retries
9383
5. Comprehensive Metal kernel error recovery
9484
"""
9585

9686
print("\n" + "=" * 100)
97-
print("🛡️ BULLETPROOF CUSTOM GQA ATTENTION EVALUATION")
87+
print("🛡️ THREAD-SAFE BULLETPROOF CUSTOM GQA ATTENTION EVALUATION")
9888
print("=" * 100)
9989
print("✅ Comprehensive Metal kernel error protection")
100-
print("✅ Signal-based timeout handling")
90+
print("✅ Thread-safe operation (no signal dependencies)")
10191
print("✅ Multi-layer exception catching")
10292
print("✅ Automatic retry with exponential backoff")
10393
print("✅ Never crashes the evolution process")
@@ -111,7 +101,7 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
111101

112102
# Step 1: Ultra-safe extraction
113103
print("\n🔧 STEP 1: Ultra-Safe Custom Attention Class Extraction")
114-
extraction_result = self._bulletproof_extract_custom_attention_class(program_text)
104+
extraction_result = self._thread_safe_extract_custom_attention_class(program_text)
115105
if not extraction_result["success"]:
116106
return self._create_failure_result(f"Extraction failed: {extraction_result['error']}")
117107

@@ -123,9 +113,9 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
123113
if not baseline_results:
124114
return self._create_failure_result("Failed to measure baseline performance safely")
125115

126-
# Step 3: Bulletproof correctness testing
127-
print("\n🔍 STEP 3: Bulletproof Custom Attention Correctness Testing")
128-
correctness_result = self._bulletproof_correctness_test(custom_attention_class)
116+
# Step 3: Thread-safe correctness testing
117+
print("\n🔍 STEP 3: Thread-Safe Custom Attention Correctness Testing")
118+
correctness_result = self._thread_safe_correctness_test(custom_attention_class)
129119
if not correctness_result["success"]:
130120
return self._create_failure_result(f"Correctness test failed: {correctness_result['error']}")
131121

@@ -184,10 +174,10 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
184174
traceback.print_exc()
185175
return self._create_failure_result(error_msg)
186176

187-
def _bulletproof_extract_custom_attention_class(self, program_text: str) -> Dict[str, Any]:
188-
"""Ultra-safe extraction with comprehensive error handling"""
177+
def _thread_safe_extract_custom_attention_class(self, program_text: str) -> Dict[str, Any]:
178+
"""Thread-safe extraction with comprehensive error handling"""
189179
try:
190-
print(" 🔍 Ultra-safe program analysis...")
180+
print(" 🔍 Thread-safe program analysis...")
191181

192182
# Handle file paths vs direct text
193183
if (
@@ -217,15 +207,14 @@ def _bulletproof_extract_custom_attention_class(self, program_text: str) -> Dict
217207
return {"success": False, "error": f"Compilation error: {e}"}
218208

219209
# Create bulletproof execution environment
220-
exec_globals = self._create_bulletproof_execution_environment()
210+
exec_globals = self._create_safe_execution_environment()
221211

222-
# Execute program with comprehensive protection
212+
# Execute program with comprehensive protection (no timeouts)
223213
print(" ⚙️ Executing program with maximum protection...")
224214
try:
225-
# Use timeout protection even for program execution
226-
success, result = self._execute_with_metal_protection(
227-
lambda: exec(actual_program_text, exec_globals),
228-
timeout=30 # 30 second timeout for program execution
215+
# Use thread-safe execution
216+
success, result = self._thread_safe_execute_with_protection(
217+
lambda: exec(actual_program_text, exec_globals)
229218
)
230219

231220
if not success:
@@ -258,7 +247,7 @@ def _bulletproof_extract_custom_attention_class(self, program_text: str) -> Dict
258247
except Exception as e:
259248
return {"success": False, "error": f"Extraction failed with exception: {str(e)}"}
260249

261-
def _create_bulletproof_execution_environment(self) -> Dict[str, Any]:
250+
def _create_safe_execution_environment(self) -> Dict[str, Any]:
262251
"""Create ultra-safe execution environment"""
263252
import math
264253
import numpy as np
@@ -309,10 +298,9 @@ def _protected_measure_baseline_performance(self) -> Optional[List[BenchmarkResu
309298
print(f" [{i}/{len(baseline_configs)}] Protected baseline: {config.name}")
310299

311300
try:
312-
# Run with Metal kernel protection
313-
success, result = self._execute_with_metal_protection(
314-
lambda: self.benchmark_suite.run_single_benchmark(config),
315-
timeout=90 # 90 second timeout per benchmark
301+
# Run with thread-safe Metal kernel protection
302+
success, result = self._thread_safe_execute_with_protection(
303+
lambda: self.benchmark_suite.run_single_benchmark(config)
316304
)
317305

318306
if success and result:
@@ -344,9 +332,9 @@ def _protected_measure_baseline_performance(self) -> Optional[List[BenchmarkResu
344332
print(f" ❌ Protected baseline measurement failed: {e}")
345333
return None
346334

347-
def _bulletproof_correctness_test(self, custom_attention_class: Any) -> Dict[str, Any]:
348-
"""Bulletproof correctness testing with maximum protection"""
349-
print(" 🔍 Running bulletproof correctness testing...")
335+
def _thread_safe_correctness_test(self, custom_attention_class: Any) -> Dict[str, Any]:
336+
"""Thread-safe correctness testing with maximum protection"""
337+
print(" 🔍 Running thread-safe correctness testing...")
350338

351339
try:
352340
# Create safe test configuration
@@ -375,17 +363,16 @@ class MockArgs:
375363
local_timeout_errors = 0
376364

377365
for B, L, D in test_cases:
378-
print(f" 🧪 Testing sequence length {L} with maximum protection...")
366+
print(f" 🧪 Testing sequence length {L} with thread-safe protection...")
379367

380368
try:
381369
# Create test inputs
382370
x = mx.random.normal((B, L, D))
383371
mask = "causal"
384372

385-
# Test with bulletproof execution
386-
success, result = self._execute_with_metal_protection(
387-
lambda: self._test_single_sequence_safely(custom_attention_class, args, x, mask),
388-
timeout=self.metal_kernel_timeout
373+
# Test with thread-safe execution
374+
success, result = self._thread_safe_execute_with_protection(
375+
lambda: self._test_single_sequence_safely(custom_attention_class, args, x, mask)
389376
)
390377

391378
if success:
@@ -432,7 +419,7 @@ class MockArgs:
432419
}
433420

434421
except Exception as e:
435-
print(f" ❌ Bulletproof correctness testing failed: {e}")
422+
print(f" ❌ Thread-safe correctness testing failed: {e}")
436423
return {"success": False, "error": str(e)}
437424

438425
def _test_single_sequence_safely(self, custom_attention_class: Any, args: Any, x: Any, mask: Any) -> float:
@@ -518,9 +505,8 @@ def _armored_benchmark_custom_attention(self, custom_attention_class: Any) -> Di
518505

519506
try:
520507
# Run with comprehensive protection
521-
success, result = self._execute_with_metal_protection(
522-
lambda: self.benchmark_suite.run_single_benchmark(config),
523-
timeout=120 # 2 minute timeout per benchmark
508+
success, result = self._thread_safe_execute_with_protection(
509+
lambda: self.benchmark_suite.run_single_benchmark(config)
524510
)
525511

526512
if success and result:
@@ -565,28 +551,13 @@ def _armored_benchmark_custom_attention(self, custom_attention_class: Any) -> Di
565551

566552
return {"success": False, "error": "All armored attempts exhausted"}
567553

568-
def _execute_with_metal_protection(self, func, timeout: int) -> Tuple[bool, Any]:
569-
"""Execute function with comprehensive Metal kernel protection"""
570-
571-
# Timeout handler using signals (Unix systems)
572-
def timeout_handler(signum, frame):
573-
raise TimeoutError(f"Operation timed out after {timeout} seconds")
574-
575-
# Set up timeout protection if available
576-
old_handler = None
577-
if hasattr(signal, 'SIGALRM'):
578-
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
579-
signal.alarm(timeout)
580-
554+
def _thread_safe_execute_with_protection(self, func) -> Tuple[bool, Any]:
555+
"""Thread-safe execution with comprehensive Metal kernel protection (no signals)"""
581556
try:
582557
# Execute the function with comprehensive error catching
583558
result = func()
584559
return True, result
585560

586-
except TimeoutError as e:
587-
self.timeout_errors_caught += 1
588-
return False, f"Timeout error: {str(e)}"
589-
590561
except Exception as e:
591562
error_msg = str(e)
592563

@@ -597,12 +568,6 @@ def timeout_handler(signum, frame):
597568
return False, f"Metal kernel error: {error_msg}"
598569
else:
599570
return False, f"Execution error: {error_msg}"
600-
601-
finally:
602-
# Clean up timeout signal
603-
if hasattr(signal, 'SIGALRM') and old_handler is not None:
604-
signal.alarm(0)
605-
signal.signal(signal.SIGALRM, old_handler)
606571

607572
def _protected_apply_custom_attention_hook(self, custom_attention_class: Any) -> Dict[str, Any]:
608573
"""Protected application of custom attention hook"""
@@ -882,7 +847,7 @@ def _generate_summary(self, performance_analysis: Dict[str, Any], correctness: f
882847
def _print_evaluation_results(self, result: Dict[str, Any]):
883848
"""Print comprehensive evaluation results"""
884849
print(f"\n{'='*100}")
885-
print(f"{'🎯 BULLETPROOF EVALUATION RESULTS':^100}")
850+
print(f"{'🎯 THREAD-SAFE EVALUATION RESULTS':^100}")
886851
print(f"{'='*100}")
887852

888853
if result["success"]:
@@ -946,13 +911,13 @@ def _result_to_dict(self, result: BenchmarkResult) -> Dict:
946911

947912
def evaluate(program_text: str) -> Dict[str, Any]:
948913
"""Main evaluation function called by OpenEvolve"""
949-
evaluator = RobustCustomGQAEvaluator()
914+
evaluator = ThreadSafeRobustEvaluator()
950915
return evaluator.evaluate(program_text)
951916

952917

953-
def test_robust_evaluator():
954-
"""Test the bulletproof evaluator"""
955-
print("🧪 Testing Bulletproof Custom GQA Evaluator")
918+
def test_thread_safe_evaluator():
919+
"""Test the thread-safe evaluator"""
920+
print("🧪 Testing Thread-Safe Robust Custom GQA Evaluator")
956921
print("=" * 80)
957922

958923
initial_program_path = os.path.join(os.path.dirname(__file__), "initial_program.py")
@@ -965,7 +930,7 @@ def test_robust_evaluator():
965930
result = evaluate(initial_program_path)
966931

967932
print(f"\n{'='*80}")
968-
print(f"🔬 BULLETPROOF EVALUATOR TEST RESULTS")
933+
print(f"🔬 THREAD-SAFE EVALUATOR TEST RESULTS")
969934
print(f"{'='*80}")
970935
print(f"Success: {result['success']}")
971936
print(f"Final Score: {result.get('final_score', 'N/A')}")
@@ -980,4 +945,4 @@ def test_robust_evaluator():
980945

981946

982947
if __name__ == "__main__":
983-
test_robust_evaluator()
948+
test_thread_safe_evaluator()

0 commit comments

Comments
 (0)