11"""
2- Robust Qwen3 Custom GQA Attention Evaluator with Comprehensive Metal Kernel Error Handling
2+ Thread-Safe Robust Qwen3 Custom GQA Attention Evaluator
33
4- This evaluator provides bulletproof protection against Metal kernel failures that terminate evolution :
4+ This evaluator provides bulletproof protection against Metal kernel failures without using signals :
55
6- 🛡️ PROTECTION FEATURES:
7- 1. Signal-based timeout handling for hanging Metal kernels
8- 2. Comprehensive C++ exception catching with try-catch blocks
9- 3. Process isolation for dangerous Metal kernel execution
10- 4. Retry mechanisms with exponential backoff
11- 5. Graceful fallback to standard attention on failures
12- 6. Detailed error classification and recovery strategies
6+ 🛡️ THREAD-SAFE PROTECTION:
7+ 1. No signal-based timeouts (works in worker threads)
8+ 2. Comprehensive C++ exception catching
9+ 3. Retry mechanisms with exponential backoff
10+ 4. Graceful fallback to standard attention on failures
11+ 5. Detailed error classification and recovery
1312
1413🔧 EVOLUTION SAFETY:
1514- Never terminates the evolution process due to kernel errors
15+ - Works perfectly in OpenEvolve's worker threads
1616- Provides meaningful feedback on kernel failure types
17- - Maintains evaluation progress even with problematic kernels
1817- Statistical tracking of Metal kernel error patterns
19-
20- Evolution Target:
21- - Custom GQA implementation using MLX primitives
22- - 40:8 query-to-KV head pattern optimization
23- - Safe evolution despite Metal kernel instability
2418"""
2519
2620import os
2721import sys
2822import json
2923import time
3024import traceback
31- import signal
32- import subprocess
33- import tempfile
25+ import threading
3426from typing import Dict , List , Tuple , Any , Optional
3527import numpy as np
3628
@@ -49,21 +41,20 @@ class MetalKernelError(Exception):
4941 pass
5042
5143
52- class TimeoutError (Exception ):
53- """Custom timeout exception for compatibility """
44+ class ThreadSafeTimeoutError (Exception ):
45+ """Thread-safe timeout exception"""
5446 pass
5547
5648
57- class RobustCustomGQAEvaluator :
58- """Bulletproof evaluator that never crashes from Metal kernel errors"""
49+ class ThreadSafeRobustEvaluator :
50+ """Thread-safe bulletproof evaluator that never crashes from Metal kernel errors"""
5951
6052 def __init__ (self ):
6153 self .model_path = "mlx-community/Qwen3-0.6B-bf16"
6254
63- # Error handling configuration
64- self .metal_kernel_timeout = 45 # 45 second timeout for Metal operations
55+ # Error handling configuration (no signal-based timeouts)
56+ self .metal_kernel_timeout = 45 # Reference only, no actual timeout enforcement
6557 self .max_retry_attempts = 2
66- self .use_process_isolation = False # Disable for now, causes import issues
6758
6859 # Error tracking
6960 self .metal_errors_caught = 0
@@ -77,27 +68,26 @@ def __init__(self):
7768 # Use comprehensive benchmark suite for consistency
7869 self .benchmark_suite = Qwen3BenchmarkSuite (self .model_path )
7970
80- print ("🛡️ Initialized Robust Custom GQA Evaluator" )
71+ print ("🛡️ Initialized Thread-Safe Robust Custom GQA Evaluator" )
8172 print (f"📱 Model: { self .model_path } " )
82- print (f"⏱️ Metal kernel timeout: { self .metal_kernel_timeout } s" )
8373 print (f"🔁 Max retry attempts: { self .max_retry_attempts } " )
84- print (f"🚫 Process isolation: { self . use_process_isolation } " )
74+ print (f"🧵 Thread-safe: No signal dependencies " )
8575
8676 def evaluate (self , program_text : str ) -> Dict [str , Any ]:
8777 """
88- Bulletproof evaluation that never crashes:
78+ Thread-safe bulletproof evaluation that never crashes:
8979 1. Safe extraction with syntax validation
9080 2. Protected baseline measurement
91- 3. Isolated correctness testing with timeouts
81+ 3. Isolated correctness testing
9282 4. Robust benchmarking with retries
9383 5. Comprehensive Metal kernel error recovery
9484 """
9585
9686 print ("\n " + "=" * 100 )
97- print ("🛡️ BULLETPROOF CUSTOM GQA ATTENTION EVALUATION" )
87+ print ("🛡️ THREAD-SAFE BULLETPROOF CUSTOM GQA ATTENTION EVALUATION" )
9888 print ("=" * 100 )
9989 print ("✅ Comprehensive Metal kernel error protection" )
100- print ("✅ Signal-based timeout handling " )
90+ print ("✅ Thread-safe operation (no signal dependencies) " )
10191 print ("✅ Multi-layer exception catching" )
10292 print ("✅ Automatic retry with exponential backoff" )
10393 print ("✅ Never crashes the evolution process" )
@@ -111,7 +101,7 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
111101
112102 # Step 1: Ultra-safe extraction
113103 print ("\n 🔧 STEP 1: Ultra-Safe Custom Attention Class Extraction" )
114- extraction_result = self ._bulletproof_extract_custom_attention_class (program_text )
104+ extraction_result = self ._thread_safe_extract_custom_attention_class (program_text )
115105 if not extraction_result ["success" ]:
116106 return self ._create_failure_result (f"Extraction failed: { extraction_result ['error' ]} " )
117107
@@ -123,9 +113,9 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
123113 if not baseline_results :
124114 return self ._create_failure_result ("Failed to measure baseline performance safely" )
125115
126- # Step 3: Bulletproof correctness testing
127- print ("\n 🔍 STEP 3: Bulletproof Custom Attention Correctness Testing" )
128- correctness_result = self ._bulletproof_correctness_test (custom_attention_class )
116+ # Step 3: Thread-safe correctness testing
117+ print ("\n 🔍 STEP 3: Thread-Safe Custom Attention Correctness Testing" )
118+ correctness_result = self ._thread_safe_correctness_test (custom_attention_class )
129119 if not correctness_result ["success" ]:
130120 return self ._create_failure_result (f"Correctness test failed: { correctness_result ['error' ]} " )
131121
@@ -184,10 +174,10 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
184174 traceback .print_exc ()
185175 return self ._create_failure_result (error_msg )
186176
187- def _bulletproof_extract_custom_attention_class (self , program_text : str ) -> Dict [str , Any ]:
188- """Ultra -safe extraction with comprehensive error handling"""
177+ def _thread_safe_extract_custom_attention_class (self , program_text : str ) -> Dict [str , Any ]:
178+ """Thread -safe extraction with comprehensive error handling"""
189179 try :
190- print (" 🔍 Ultra -safe program analysis..." )
180+ print (" 🔍 Thread -safe program analysis..." )
191181
192182 # Handle file paths vs direct text
193183 if (
@@ -217,15 +207,14 @@ def _bulletproof_extract_custom_attention_class(self, program_text: str) -> Dict
217207 return {"success" : False , "error" : f"Compilation error: { e } " }
218208
219209 # Create bulletproof execution environment
220- exec_globals = self ._create_bulletproof_execution_environment ()
210+ exec_globals = self ._create_safe_execution_environment ()
221211
222- # Execute program with comprehensive protection
212+ # Execute program with comprehensive protection (no timeouts)
223213 print (" ⚙️ Executing program with maximum protection..." )
224214 try :
225- # Use timeout protection even for program execution
226- success , result = self ._execute_with_metal_protection (
227- lambda : exec (actual_program_text , exec_globals ),
228- timeout = 30 # 30 second timeout for program execution
215+ # Use thread-safe execution
216+ success , result = self ._thread_safe_execute_with_protection (
217+ lambda : exec (actual_program_text , exec_globals )
229218 )
230219
231220 if not success :
@@ -258,7 +247,7 @@ def _bulletproof_extract_custom_attention_class(self, program_text: str) -> Dict
258247 except Exception as e :
259248 return {"success" : False , "error" : f"Extraction failed with exception: { str (e )} " }
260249
261- def _create_bulletproof_execution_environment (self ) -> Dict [str , Any ]:
250+ def _create_safe_execution_environment (self ) -> Dict [str , Any ]:
262251 """Create ultra-safe execution environment"""
263252 import math
264253 import numpy as np
@@ -309,10 +298,9 @@ def _protected_measure_baseline_performance(self) -> Optional[List[BenchmarkResu
309298 print (f" [{ i } /{ len (baseline_configs )} ] Protected baseline: { config .name } " )
310299
311300 try :
312- # Run with Metal kernel protection
313- success , result = self ._execute_with_metal_protection (
314- lambda : self .benchmark_suite .run_single_benchmark (config ),
315- timeout = 90 # 90 second timeout per benchmark
301+ # Run with thread-safe Metal kernel protection
302+ success , result = self ._thread_safe_execute_with_protection (
303+ lambda : self .benchmark_suite .run_single_benchmark (config )
316304 )
317305
318306 if success and result :
@@ -344,9 +332,9 @@ def _protected_measure_baseline_performance(self) -> Optional[List[BenchmarkResu
344332 print (f" ❌ Protected baseline measurement failed: { e } " )
345333 return None
346334
347- def _bulletproof_correctness_test (self , custom_attention_class : Any ) -> Dict [str , Any ]:
348- """Bulletproof correctness testing with maximum protection"""
349- print (" 🔍 Running bulletproof correctness testing..." )
335+ def _thread_safe_correctness_test (self , custom_attention_class : Any ) -> Dict [str , Any ]:
336+ """Thread-safe correctness testing with maximum protection"""
337+ print (" 🔍 Running thread-safe correctness testing..." )
350338
351339 try :
352340 # Create safe test configuration
@@ -375,17 +363,16 @@ class MockArgs:
375363 local_timeout_errors = 0
376364
377365 for B , L , D in test_cases :
378- print (f" 🧪 Testing sequence length { L } with maximum protection..." )
366+ print (f" 🧪 Testing sequence length { L } with thread-safe protection..." )
379367
380368 try :
381369 # Create test inputs
382370 x = mx .random .normal ((B , L , D ))
383371 mask = "causal"
384372
385- # Test with bulletproof execution
386- success , result = self ._execute_with_metal_protection (
387- lambda : self ._test_single_sequence_safely (custom_attention_class , args , x , mask ),
388- timeout = self .metal_kernel_timeout
373+ # Test with thread-safe execution
374+ success , result = self ._thread_safe_execute_with_protection (
375+ lambda : self ._test_single_sequence_safely (custom_attention_class , args , x , mask )
389376 )
390377
391378 if success :
@@ -432,7 +419,7 @@ class MockArgs:
432419 }
433420
434421 except Exception as e :
435- print (f" ❌ Bulletproof correctness testing failed: { e } " )
422+ print (f" ❌ Thread-safe correctness testing failed: { e } " )
436423 return {"success" : False , "error" : str (e )}
437424
438425 def _test_single_sequence_safely (self , custom_attention_class : Any , args : Any , x : Any , mask : Any ) -> float :
@@ -518,9 +505,8 @@ def _armored_benchmark_custom_attention(self, custom_attention_class: Any) -> Di
518505
519506 try :
520507 # Run with comprehensive protection
521- success , result = self ._execute_with_metal_protection (
522- lambda : self .benchmark_suite .run_single_benchmark (config ),
523- timeout = 120 # 2 minute timeout per benchmark
508+ success , result = self ._thread_safe_execute_with_protection (
509+ lambda : self .benchmark_suite .run_single_benchmark (config )
524510 )
525511
526512 if success and result :
@@ -565,28 +551,13 @@ def _armored_benchmark_custom_attention(self, custom_attention_class: Any) -> Di
565551
566552 return {"success" : False , "error" : "All armored attempts exhausted" }
567553
568- def _execute_with_metal_protection (self , func , timeout : int ) -> Tuple [bool , Any ]:
569- """Execute function with comprehensive Metal kernel protection"""
570-
571- # Timeout handler using signals (Unix systems)
572- def timeout_handler (signum , frame ):
573- raise TimeoutError (f"Operation timed out after { timeout } seconds" )
574-
575- # Set up timeout protection if available
576- old_handler = None
577- if hasattr (signal , 'SIGALRM' ):
578- old_handler = signal .signal (signal .SIGALRM , timeout_handler )
579- signal .alarm (timeout )
580-
554+ def _thread_safe_execute_with_protection (self , func ) -> Tuple [bool , Any ]:
555+ """Thread-safe execution with comprehensive Metal kernel protection (no signals)"""
581556 try :
582557 # Execute the function with comprehensive error catching
583558 result = func ()
584559 return True , result
585560
586- except TimeoutError as e :
587- self .timeout_errors_caught += 1
588- return False , f"Timeout error: { str (e )} "
589-
590561 except Exception as e :
591562 error_msg = str (e )
592563
@@ -597,12 +568,6 @@ def timeout_handler(signum, frame):
597568 return False , f"Metal kernel error: { error_msg } "
598569 else :
599570 return False , f"Execution error: { error_msg } "
600-
601- finally :
602- # Clean up timeout signal
603- if hasattr (signal , 'SIGALRM' ) and old_handler is not None :
604- signal .alarm (0 )
605- signal .signal (signal .SIGALRM , old_handler )
606571
607572 def _protected_apply_custom_attention_hook (self , custom_attention_class : Any ) -> Dict [str , Any ]:
608573 """Protected application of custom attention hook"""
@@ -882,7 +847,7 @@ def _generate_summary(self, performance_analysis: Dict[str, Any], correctness: f
882847 def _print_evaluation_results (self , result : Dict [str , Any ]):
883848 """Print comprehensive evaluation results"""
884849 print (f"\n { '=' * 100 } " )
885- print (f"{ '🎯 BULLETPROOF EVALUATION RESULTS' :^100} " )
850+ print (f"{ '🎯 THREAD-SAFE EVALUATION RESULTS' :^100} " )
886851 print (f"{ '=' * 100 } " )
887852
888853 if result ["success" ]:
@@ -946,13 +911,13 @@ def _result_to_dict(self, result: BenchmarkResult) -> Dict:
946911
947912def evaluate (program_text : str ) -> Dict [str , Any ]:
948913 """Main evaluation function called by OpenEvolve"""
949- evaluator = RobustCustomGQAEvaluator ()
914+ evaluator = ThreadSafeRobustEvaluator ()
950915 return evaluator .evaluate (program_text )
951916
952917
953- def test_robust_evaluator ():
954- """Test the bulletproof evaluator"""
955- print ("🧪 Testing Bulletproof Custom GQA Evaluator" )
918+ def test_thread_safe_evaluator ():
919+ """Test the thread-safe evaluator"""
920+ print ("🧪 Testing Thread-Safe Robust Custom GQA Evaluator" )
956921 print ("=" * 80 )
957922
958923 initial_program_path = os .path .join (os .path .dirname (__file__ ), "initial_program.py" )
@@ -965,7 +930,7 @@ def test_robust_evaluator():
965930 result = evaluate (initial_program_path )
966931
967932 print (f"\n { '=' * 80 } " )
968- print (f"🔬 BULLETPROOF EVALUATOR TEST RESULTS" )
933+ print (f"🔬 THREAD-SAFE EVALUATOR TEST RESULTS" )
969934 print (f"{ '=' * 80 } " )
970935 print (f"Success: { result ['success' ]} " )
971936 print (f"Final Score: { result .get ('final_score' , 'N/A' )} " )
@@ -980,4 +945,4 @@ def test_robust_evaluator():
980945
981946
982947if __name__ == "__main__" :
983- test_robust_evaluator ()
948+ test_thread_safe_evaluator ()
0 commit comments