11"""
2- Two-Stage Evaluator for MLX Block Diagonal Attention Optimization
2+ Robust Two-Stage Evaluator for MLX Block Diagonal Attention Optimization
33
44STAGE 1: Correctness & Compatibility Gate
55- Ensures evolved programs produce correct outputs
2020import math
2121import time
2222import traceback
23- from typing import Dict , List , Tuple
23+ from typing import Dict , List , Tuple , Union
2424import gc
2525
2626import mlx .core as mx
3030from spda_benchmark import prepare_inputs , mlx_ref_attn , mlx_fused_attn , do_attention , bench
3131
3232
33+ def safe_format_percentage (value , fallback = "N/A%" ):
34+ """
35+ Safely format a value as a percentage.
36+
37+ Args:
38+ value: Value to format as percentage (should be between 0 and 1)
39+ fallback: Fallback string if formatting fails
40+
41+ Returns:
42+ Formatted percentage string
43+ """
44+ try :
45+ if isinstance (value , (int , float )) and not math .isnan (value ) and not math .isinf (value ):
46+ return f"{ value :.1%} "
47+ else :
48+ return fallback
49+ except (ValueError , TypeError ):
50+ return fallback
51+
52+
53+ def safe_format_number (value : Union [float , int , str ], format_spec : str = ".3f" , fallback : str = "N/A" ) -> str :
54+ """
55+ Safely format a number with fallback for non-numeric values.
56+ This prevents "Unknown format code 'f' for object of type 'str'" errors.
57+ """
58+ try :
59+ if isinstance (value , (int , float )) and not math .isnan (value ) and not math .isinf (value ):
60+ return f"{ value :{format_spec }} "
61+ elif value == float ("inf" ):
62+ return "∞"
63+ elif value == float ("-inf" ):
64+ return "-∞"
65+ elif isinstance (value , float ) and math .isnan (value ):
66+ return "NaN"
67+ else :
68+ return str (value ) if value is not None else fallback
69+ except (ValueError , TypeError ):
70+ return fallback
71+
72+
3373def create_stage1_test_configurations () -> List [Dict ]:
3474 """
3575 Stage 1: Comprehensive correctness tests based on spda_benchmark.
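
A quick spot-check of the two new helpers (a sketch, not part of the commit; it assumes the helpers above are importable as written):

# Hypothetical doctest-style checks for the new formatting helpers.
assert safe_format_percentage(0.925) == "92.5%"
assert safe_format_percentage(float("nan")) == "N/A%"  # non-finite -> fallback
assert safe_format_number(1.2345, ".2f") == "1.23"
assert safe_format_number(float("inf")) == "∞"
assert safe_format_number("timeout") == "timeout"      # strings pass through unformatted
assert safe_format_number(None) == "N/A"               # None -> fallback
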
@@ -141,45 +181,59 @@ def create_stage2_performance_configurations() -> List[Dict]:
 def compare_attention_outputs(output1: mx.array, output2: mx.array, tolerance: float = 1e-3) -> Dict[str, float]:
     """
     Compare two attention outputs with appropriate tolerance.
-    Enhanced version from original evaluator.
+    Enhanced version with robust error handling.
     """
-    # Ensure arrays are evaluated
-    output1 = mx.array(output1)
-    output2 = mx.array(output2)
-    mx.eval(output1, output2)
-
-    # Calculate various similarity metrics
-    diff = output1 - output2
-    mse = float(mx.mean(diff ** 2))
-    mae = float(mx.mean(mx.abs(diff)))
-    max_diff = float(mx.max(mx.abs(diff)))
-
-    # Relative error (normalized by output magnitude)
-    output1_norm = float(mx.sqrt(mx.mean(output1 ** 2)))
-    relative_error = float(mx.sqrt(mx.mean(diff ** 2))) / max(output1_norm, 1e-8)
-
-    # Check MLX's allclose function
-    allclose_result = bool(mx.allclose(output1, output2, atol=tolerance, rtol=tolerance))
-
-    # Additional robust check: if MSE is extremely small, consider it a match
-    mse_perfect = mse < 1e-8
-
-    # Final decision: either allclose passes OR MSE is extremely small
-    final_allclose = allclose_result or mse_perfect
-
-    return {
-        "mse": mse,
-        "mae": mae,
-        "max_diff": max_diff,
-        "relative_error": relative_error,
-        "allclose": final_allclose,
-        "allclose_strict": allclose_result,
-        "mse_perfect": mse_perfect,
-        "tolerance_used": tolerance,
-    }
-
-
-def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str, float]:
+    try:
+        # Ensure arrays are evaluated
+        output1 = mx.array(output1)
+        output2 = mx.array(output2)
+        mx.eval(output1, output2)
+
+        # Calculate various similarity metrics
+        diff = output1 - output2
+        mse = float(mx.mean(diff ** 2))
+        mae = float(mx.mean(mx.abs(diff)))
+        max_diff = float(mx.max(mx.abs(diff)))
+
+        # Relative error (normalized by output magnitude)
+        output1_norm = float(mx.sqrt(mx.mean(output1 ** 2)))
+        relative_error = float(mx.sqrt(mx.mean(diff ** 2))) / max(output1_norm, 1e-8)
+
+        # Check MLX's allclose function
+        allclose_result = bool(mx.allclose(output1, output2, atol=tolerance, rtol=tolerance))
+
+        # Additional robust check: if MSE is extremely small, consider it a match
+        mse_perfect = mse < 1e-8
+
+        # Final decision: either allclose passes OR MSE is extremely small
+        final_allclose = allclose_result or mse_perfect
+
+        return {
+            "mse": mse,
+            "mae": mae,
+            "max_diff": max_diff,
+            "relative_error": relative_error,
+            "allclose": final_allclose,
+            "allclose_strict": allclose_result,
+            "mse_perfect": mse_perfect,
+            "tolerance_used": tolerance,
+        }
+    except Exception as e:
+        # Fallback values if comparison fails
+        return {
+            "mse": float("inf"),
+            "mae": float("inf"),
+            "max_diff": float("inf"),
+            "relative_error": float("inf"),
+            "allclose": False,
+            "allclose_strict": False,
+            "mse_perfect": False,
+            "tolerance_used": tolerance,
+            "comparison_error": str(e),
+        }
+
+
+def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str, Union[bool, float, str]]:
     """
     Stage 1: Test correctness with category-appropriate tolerances.
 
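
With the comparison wrapped in try/except, a shape mismatch or a failed evaluation now degrades to a failing record instead of aborting the whole run. A sketch of the degraded path (hypothetical shapes; it assumes the mismatched subtraction raises, which MLX broadcasting rules make likely here):

import mlx.core as mx

# Outputs whose sequence lengths disagree cannot be subtracted elementwise.
metrics = compare_attention_outputs(mx.zeros((1, 8, 64, 64)), mx.zeros((1, 8, 32, 64)))
# The fallback dict marks the comparison as failed rather than raising:
# metrics["allclose"] is False, metrics["mse"] == float("inf"),
# and metrics["comparison_error"] carries the exception text.
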
@@ -244,7 +298,7 @@ def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str,
         # For shorter sequences, compute reference for comparison
         try:
             reference_output = mlx_ref_attn(q, k, v, scale=scale, mask=mask)
-        except Exception:
+        except Exception as ref_error:
             # Reference failed, check structural validity only
             has_nan = bool(mx.any(mx.isnan(evolved_output)))
             has_inf = bool(mx.any(mx.isinf(evolved_output)))
@@ -258,7 +312,7 @@ def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str,
258312 "tolerance_used" : tolerance ,
259313 "category" : category ,
260314 "reference_computed" : False ,
261- "reference_error" : "Reference computation failed" ,
315+ "reference_error" : str ( ref_error ) ,
262316 }
263317
264318 # Compare outputs with category-appropriate tolerance
@@ -293,7 +347,7 @@ def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str,
     }
 
 
-def benchmark_performance(evolved_fn, config: Dict, num_trials: int = 3) -> Dict[str, float]:
+def benchmark_performance(evolved_fn, config: Dict, num_trials: int = 3) -> Dict[str, Union[float, str]]:
     """
     Stage 2: Benchmark performance vs mx.fast.scaled_dot_product_attention.
     """
@@ -388,7 +442,7 @@ def benchmark_performance(evolved_fn, config: Dict, num_trials: int = 3) -> Dict
         return {"speedup": 0.0, "performance_score": 0.0, "error": str(e)}
 
 
-def evaluate_two_stage(program_path: str) -> Dict[str, float]:
+def evaluate_two_stage(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
     """
     Two-stage evaluation: Correctness gate + Performance optimization.
     """
@@ -431,19 +485,25 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
         if result["passed"]:
             stage1_passed_count += 1
             mse_val = result.get('mse', 'N/A')
-            if isinstance(mse_val, (int, float)) and not math.isnan(mse_val) and not math.isinf(mse_val):
-                mse_str = f"{mse_val:.2e}"
-            else:
-                mse_str = str(mse_val)
+            mse_str = safe_format_number(mse_val, ".2e")
             print(f"  ✅ PASSED: MSE={mse_str}")
         else:
-            print(f"  ❌ FAILED: {result.get('error', 'Accuracy/structure issue')}")
+            error_msg = result.get('error', 'Accuracy/structure issue')
+            print(f"  ❌ FAILED: {error_msg}")
 
-    stage1_pass_rate = stage1_passed_count / len(stage1_configs)
+    # Safe calculation of stage1_pass_rate to prevent division errors
+    try:
+        stage1_pass_rate = stage1_passed_count / len(stage1_configs) if len(stage1_configs) > 0 else 0.0
+    except (TypeError, ZeroDivisionError):
+        stage1_pass_rate = 0.0
+
     stage1_passed = stage1_pass_rate >= 0.9  # 90% pass rate required
 
+    # Safe formatting for stage1_pass_rate
+    stage1_pass_rate_str = safe_format_percentage(stage1_pass_rate)
+
     print(f"\n📊 STAGE 1 Results:")
-    print(f"  Passed: {stage1_passed_count}/{len(stage1_configs)} ({stage1_pass_rate:.1%})")
+    print(f"  Passed: {stage1_passed_count}/{len(stage1_configs)} ({stage1_pass_rate_str})")
     print(f"  Gate Status: {'✅ PASSED' if stage1_passed else '❌ FAILED'}")
 
     if not stage1_passed:
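
For concreteness, the 90% gate works out as follows (hypothetical counts): 18 of 20 passing configs is exactly 0.9 and clears the gate, while an empty config list now yields a rate of 0.0 instead of a ZeroDivisionError, so the gate fails closed:

# Hypothetical gate arithmetic mirroring the guarded division above.
passed, total = 18, 20
rate = passed / total if total > 0 else 0.0  # 0.9: exactly at the gate
assert rate >= 0.9
rate_empty = 0 / 0 if 0 > 0 else 0.0         # no configs -> 0.0, gate fails closed
assert rate_empty == 0.0
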
@@ -484,41 +544,44 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
484544 "weighted_score" : weighted_score ,
485545 })
486546
487- # Safe formatting for speedup
488- if isinstance (speedup , (int , float )) and not math .isnan (speedup ) and not math .isinf (speedup ):
489- speedup_str = f"{ speedup :.2f} "
490- elif speedup == float ("inf" ):
491- speedup_str = "∞"
492- else :
493- speedup_str = str (speedup )
494-
495- if isinstance (perf_score , (int , float )) and not math .isnan (perf_score ) and not math .isinf (perf_score ):
496- perf_str = f"{ perf_score :.3f} "
497- else :
498- perf_str = str (perf_score )
547+ # Safe formatting for speedup and performance score
548+ speedup_str = safe_format_number (speedup , ".2f" )
549+ perf_str = safe_format_number (perf_score , ".3f" )
499550
500551 print (f" 📊 Speedup: { speedup_str } x, Score: { perf_str } " )
501552
502- stage2_score = total_weighted_score / total_weight if total_weight > 0 else 0.0
503-
553+ # Safe calculation of stage2_score to prevent division errors
554+ try :
555+ stage2_score = total_weighted_score / total_weight if total_weight > 0 else 0.0
556+ except (TypeError , ZeroDivisionError ):
557+ stage2_score = 0.0
558+
504559 # Calculate overall score (Stage 1 gate + Stage 2 performance)
505560 overall_score = stage2_score # Since Stage 1 is just a gate
506561
507- # Detailed performance analysis
508- speedups = [r ["benchmark" ]["speedup" ] for r in stage2_results
509- if isinstance (r ["benchmark" ]["speedup" ], (int , float )) and
510- r ["benchmark" ]["speedup" ] != float ("inf" ) and
511- not math .isnan (r ["benchmark" ]["speedup" ])]
512- avg_speedup = np .mean (speedups ) if speedups else 0.0
513- max_speedup = max (speedups ) if speedups else 0.0
562+ # Detailed performance analysis with safe operations
563+ speedups = []
564+ for r in stage2_results :
565+ speedup_val = r ["benchmark" ]["speedup" ]
566+ if (isinstance (speedup_val , (int , float )) and
567+ speedup_val != float ("inf" ) and
568+ not math .isnan (speedup_val )):
569+ speedups .append (speedup_val )
570+
571+ try :
572+ avg_speedup = np .mean (speedups ) if speedups else 0.0
573+ max_speedup = max (speedups ) if speedups else 0.0
574+ except (TypeError , ValueError ):
575+ avg_speedup = 0.0
576+ max_speedup = 0.0
514577
515578 print (f"\n 📈 STAGE 2 Results:" )
516579
517- # Safe formatting
518- stage2_str = f" { stage2_score :.3f } " if isinstance (stage2_score , ( int , float )) else str ( stage2_score )
519- avg_speedup_str = f" { avg_speedup :.2f } " if isinstance (avg_speedup , ( int , float )) else str ( avg_speedup )
520- max_speedup_str = f" { max_speedup :.2f } " if isinstance (max_speedup , ( int , float )) else str ( max_speedup )
521- overall_str = f" { overall_score :.3f } " if isinstance (overall_score , ( int , float )) else str ( overall_score )
580+ # Safe formatting for final results
581+ stage2_str = safe_format_number (stage2_score , ".3f" )
582+ avg_speedup_str = safe_format_number (avg_speedup , ".2f" )
583+ max_speedup_str = safe_format_number (max_speedup , ".2f" )
584+ overall_str = safe_format_number (overall_score , ".3f" )
522585
523586 print (f" Performance Score: { stage2_str } " )
524587 print (f" Average Speedup: { avg_speedup_str } x" )
@@ -538,18 +601,32 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
     else:
         print(f"  ❌ POOR: Need significant optimization")
 
+    # Ensure all return values are safe numeric types
+    try:
+        safe_stage1_pass_rate = float(stage1_pass_rate) if isinstance(stage1_pass_rate, (int, float)) else 0.0
+        safe_stage2_score = float(stage2_score) if isinstance(stage2_score, (int, float)) else 0.0
+        safe_overall_score = float(overall_score) if isinstance(overall_score, (int, float)) else 0.0
+        safe_avg_speedup = float(avg_speedup) if isinstance(avg_speedup, (int, float)) else 0.0
+        safe_max_speedup = float(max_speedup) if isinstance(max_speedup, (int, float)) else 0.0
+    except (TypeError, ValueError):
+        safe_stage1_pass_rate = 0.0
+        safe_stage2_score = 0.0
+        safe_overall_score = 0.0
+        safe_avg_speedup = 0.0
+        safe_max_speedup = 0.0
+
     return {
         # Gate results
         "stage1_passed": stage1_passed,
-        "stage1_pass_rate": stage1_pass_rate,
+        "stage1_pass_rate": safe_stage1_pass_rate,
 
         # Performance results
-        "stage2_score": float(stage2_score),
-        "overall_score": float(overall_score),
+        "stage2_score": safe_stage2_score,
+        "overall_score": safe_overall_score,
 
         # Detailed metrics
-        "avg_speedup": float(avg_speedup),
-        "max_speedup": float(max_speedup),
+        "avg_speedup": safe_avg_speedup,
+        "max_speedup": safe_max_speedup,
         "num_stage1_tests": len(stage1_configs),
         "num_stage2_tests": len(stage2_configs),
     }
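
A sketch of the dict shape a caller can now rely on (values are hypothetical; after this change every numeric field is a plain, finite float):

# Hypothetical result from a successful run of evaluate_two_stage().
results = {
    "stage1_passed": True,       # gate outcome (bool)
    "stage1_pass_rate": 1.0,     # always a float, never a string
    "stage2_score": 0.742,
    "overall_score": 0.742,      # equals stage2_score: Stage 1 is only a gate
    "avg_speedup": 1.18,
    "max_speedup": 1.65,
    "num_stage1_tests": 20,
    "num_stage2_tests": 8,
}
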
@@ -565,16 +642,31 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
         }
 
 
-def evaluate(program_path: str) -> Dict[str, float]:
+def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
     """
     Main evaluation function - Two-stage: Correctness gate + Performance.
+    Includes comprehensive error handling to prevent formatting errors.
     """
-    return evaluate_two_stage(program_path)
+    try:
+        return evaluate_two_stage(program_path)
+    except Exception as e:
+        # Catch ANY error (including formatting errors) and return safe fallback
+        error_msg = str(e)
+        print(f"❌ Evaluation failed with error: {error_msg}")
+
+        # Return safe fallback metrics
+        return {
+            "stage1_passed": False,
+            "stage2_score": 0.0,
+            "overall_score": 0.0,
+            "error": error_msg,
+            "failed_at": "evaluation_error",
+        }
 
 
 if __name__ == "__main__":
     # Test the two-stage evaluator
-    print("Testing Two-Stage Evaluator...")
+    print("Testing Robust Two-Stage Evaluator...")
     import os
 
     initial_program_path = os.path.join(os.path.dirname(__file__), "initial_program.py")
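
Since evaluate() now swallows exceptions, a driver can branch on the fallback keys instead of wrapping the call in its own try/except. A minimal sketch (the module name evaluator is an assumption; substitute whatever this file is saved as):

# Hypothetical driver for the hardened entry point.
from evaluator import evaluate

results = evaluate("initial_program.py")
if results.get("failed_at") == "evaluation_error":
    print("evaluation crashed:", results["error"])
else:
    print("overall score:", results["overall_score"])
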
@@ -584,7 +676,8 @@ def evaluate(program_path: str) -> Dict[str, float]:
     print("\nTwo-Stage Evaluation Results:")
     for k, v in results.items():
         if isinstance(v, (int, float)):
-            print(f"  {k}: {v:.4f}")
+            formatted_v = safe_format_number(v, ".4f")
+            print(f"  {k}: {formatted_v}")
         else:
             print(f"  {k}: {v}")
     else: