Skip to content

Commit 0eae54e

Browse files
committed
Update evaluator.py
1 parent f02c217 commit 0eae54e

File tree

1 file changed

+179
-86
lines changed

1 file changed

+179
-86
lines changed

examples/mlx_spda_optimization/evaluator.py

Lines changed: 179 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Two-Stage Evaluator for MLX Block Diagonal Attention Optimization
2+
Robust Two-Stage Evaluator for MLX Block Diagonal Attention Optimization
33
44
STAGE 1: Correctness & Compatibility Gate
55
- Ensures evolved programs produce correct outputs
@@ -20,7 +20,7 @@
2020
import math
2121
import time
2222
import traceback
23-
from typing import Dict, List, Tuple
23+
from typing import Dict, List, Tuple, Union
2424
import gc
2525

2626
import mlx.core as mx
@@ -30,6 +30,46 @@
3030
from spda_benchmark import prepare_inputs, mlx_ref_attn, mlx_fused_attn, do_attention, bench
3131

3232

def safe_format_percentage(value, fallback="N/A%"):
    """Safely render *value* as a percentage string.

    Args:
        value: Number expected to lie in [0, 1]; anything non-numeric,
            NaN, or infinite falls through to *fallback*.
        fallback: String returned when *value* cannot be formatted.

    Returns:
        The percentage string (e.g. "50.0%") or *fallback*.
    """
    try:
        is_finite_number = (
            isinstance(value, (int, float))
            and not math.isnan(value)
            and not math.isinf(value)
        )
        return f"{value:.1%}" if is_finite_number else fallback
    except (ValueError, TypeError):
        return fallback
def safe_format_number(value: Union[float, int, str], format_spec: str = ".3f", fallback: str = "N/A") -> str:
    """Safely format *value* with *format_spec*, with a fallback for non-numerics.

    Guards against errors such as "Unknown format code 'f' for object of
    type 'str'" by applying the numeric format only to finite numbers.
    Infinities map to "∞"/"-∞" and NaN to "NaN"; any other value is
    stringified, with *fallback* used for None or formatting failures.
    """
    try:
        # Only a finite int/float gets the numeric format spec applied.
        if isinstance(value, (int, float)) and not math.isnan(value) and not math.isinf(value):
            return f"{value:{format_spec}}"
        if value == float("inf"):
            return "∞"
        if value == float("-inf"):
            return "-∞"
        # NaN compares unequal to everything, so it needs its own check.
        if isinstance(value, float) and math.isnan(value):
            return "NaN"
        if value is None:
            return fallback
        return str(value)
    except (ValueError, TypeError):
        return fallback
3373
def create_stage1_test_configurations() -> List[Dict]:
3474
"""
3575
Stage 1: Comprehensive correctness tests based on spda_benchmark.
@@ -141,45 +181,59 @@ def create_stage2_performance_configurations() -> List[Dict]:
141181
def compare_attention_outputs(output1: mx.array, output2: mx.array, tolerance: float = 1e-3) -> Dict[str, float]:
142182
"""
143183
Compare two attention outputs with appropriate tolerance.
144-
Enhanced version from original evaluator.
184+
Enhanced version with robust error handling.
145185
"""
146-
# Ensure arrays are evaluated
147-
output1 = mx.array(output1)
148-
output2 = mx.array(output2)
149-
mx.eval(output1, output2)
150-
151-
# Calculate various similarity metrics
152-
diff = output1 - output2
153-
mse = float(mx.mean(diff**2))
154-
mae = float(mx.mean(mx.abs(diff)))
155-
max_diff = float(mx.max(mx.abs(diff)))
156-
157-
# Relative error (normalized by output magnitude)
158-
output1_norm = float(mx.sqrt(mx.mean(output1**2)))
159-
relative_error = float(mx.sqrt(mx.mean(diff**2))) / max(output1_norm, 1e-8)
160-
161-
# Check MLX's allclose function
162-
allclose_result = bool(mx.allclose(output1, output2, atol=tolerance, rtol=tolerance))
163-
164-
# Additional robust check: if MSE is extremely small, consider it a match
165-
mse_perfect = mse < 1e-8
166-
167-
# Final decision: either allclose passes OR MSE is extremely small
168-
final_allclose = allclose_result or mse_perfect
169-
170-
return {
171-
"mse": mse,
172-
"mae": mae,
173-
"max_diff": max_diff,
174-
"relative_error": relative_error,
175-
"allclose": final_allclose,
176-
"allclose_strict": allclose_result,
177-
"mse_perfect": mse_perfect,
178-
"tolerance_used": tolerance,
179-
}
180-
181-
182-
def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str, float]:
186+
try:
187+
# Ensure arrays are evaluated
188+
output1 = mx.array(output1)
189+
output2 = mx.array(output2)
190+
mx.eval(output1, output2)
191+
192+
# Calculate various similarity metrics
193+
diff = output1 - output2
194+
mse = float(mx.mean(diff**2))
195+
mae = float(mx.mean(mx.abs(diff)))
196+
max_diff = float(mx.max(mx.abs(diff)))
197+
198+
# Relative error (normalized by output magnitude)
199+
output1_norm = float(mx.sqrt(mx.mean(output1**2)))
200+
relative_error = float(mx.sqrt(mx.mean(diff**2))) / max(output1_norm, 1e-8)
201+
202+
# Check MLX's allclose function
203+
allclose_result = bool(mx.allclose(output1, output2, atol=tolerance, rtol=tolerance))
204+
205+
# Additional robust check: if MSE is extremely small, consider it a match
206+
mse_perfect = mse < 1e-8
207+
208+
# Final decision: either allclose passes OR MSE is extremely small
209+
final_allclose = allclose_result or mse_perfect
210+
211+
return {
212+
"mse": mse,
213+
"mae": mae,
214+
"max_diff": max_diff,
215+
"relative_error": relative_error,
216+
"allclose": final_allclose,
217+
"allclose_strict": allclose_result,
218+
"mse_perfect": mse_perfect,
219+
"tolerance_used": tolerance,
220+
}
221+
except Exception as e:
222+
# Fallback values if comparison fails
223+
return {
224+
"mse": float("inf"),
225+
"mae": float("inf"),
226+
"max_diff": float("inf"),
227+
"relative_error": float("inf"),
228+
"allclose": False,
229+
"allclose_strict": False,
230+
"mse_perfect": False,
231+
"tolerance_used": tolerance,
232+
"comparison_error": str(e),
233+
}
234+
235+
236+
def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str, Union[bool, float, str]]:
183237
"""
184238
Stage 1: Test correctness with category-appropriate tolerances.
185239
@@ -244,7 +298,7 @@ def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str,
244298
# For shorter sequences, compute reference for comparison
245299
try:
246300
reference_output = mlx_ref_attn(q, k, v, scale=scale, mask=mask)
247-
except Exception:
301+
except Exception as ref_error:
248302
# Reference failed, check structural validity only
249303
has_nan = bool(mx.any(mx.isnan(evolved_output)))
250304
has_inf = bool(mx.any(mx.isinf(evolved_output)))
@@ -258,7 +312,7 @@ def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str,
258312
"tolerance_used": tolerance,
259313
"category": category,
260314
"reference_computed": False,
261-
"reference_error": "Reference computation failed",
315+
"reference_error": str(ref_error),
262316
}
263317

264318
# Compare outputs with category-appropriate tolerance
@@ -293,7 +347,7 @@ def evaluate_stage1_correctness(evolved_attention_fn, config: Dict) -> Dict[str,
293347
}
294348

295349

296-
def benchmark_performance(evolved_fn, config: Dict, num_trials: int = 3) -> Dict[str, float]:
350+
def benchmark_performance(evolved_fn, config: Dict, num_trials: int = 3) -> Dict[str, Union[float, str]]:
297351
"""
298352
Stage 2: Benchmark performance vs mx.fast.scaled_dot_product_attention.
299353
"""
@@ -388,7 +442,7 @@ def benchmark_performance(evolved_fn, config: Dict, num_trials: int = 3) -> Dict
388442
return {"speedup": 0.0, "performance_score": 0.0, "error": str(e)}
389443

390444

391-
def evaluate_two_stage(program_path: str) -> Dict[str, float]:
445+
def evaluate_two_stage(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
392446
"""
393447
Two-stage evaluation: Correctness gate + Performance optimization.
394448
"""
@@ -431,19 +485,25 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
431485
if result["passed"]:
432486
stage1_passed_count += 1
433487
mse_val = result.get('mse', 'N/A')
434-
if isinstance(mse_val, (int, float)) and not math.isnan(mse_val) and not math.isinf(mse_val):
435-
mse_str = f"{mse_val:.2e}"
436-
else:
437-
mse_str = str(mse_val)
488+
mse_str = safe_format_number(mse_val, ".2e")
438489
print(f" ✅ PASSED: MSE={mse_str}")
439490
else:
440-
print(f" ❌ FAILED: {result.get('error', 'Accuracy/structure issue')}")
491+
error_msg = result.get('error', 'Accuracy/structure issue')
492+
print(f" ❌ FAILED: {error_msg}")
441493

442-
stage1_pass_rate = stage1_passed_count / len(stage1_configs)
494+
# Safe calculation of stage1_pass_rate to prevent division errors
495+
try:
496+
stage1_pass_rate = stage1_passed_count / len(stage1_configs) if len(stage1_configs) > 0 else 0.0
497+
except (TypeError, ZeroDivisionError):
498+
stage1_pass_rate = 0.0
499+
443500
stage1_passed = stage1_pass_rate >= 0.9 # 90% pass rate required
444501

502+
# Safe formatting for stage1_pass_rate
503+
stage1_pass_rate_str = safe_format_percentage(stage1_pass_rate)
504+
445505
print(f"\n📊 STAGE 1 Results:")
446-
print(f" Passed: {stage1_passed_count}/{len(stage1_configs)} ({stage1_pass_rate:.1%})")
506+
print(f" Passed: {stage1_passed_count}/{len(stage1_configs)} ({stage1_pass_rate_str})")
447507
print(f" Gate Status: {'✅ PASSED' if stage1_passed else '❌ FAILED'}")
448508

449509
if not stage1_passed:
@@ -484,41 +544,44 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
484544
"weighted_score": weighted_score,
485545
})
486546

487-
# Safe formatting for speedup
488-
if isinstance(speedup, (int, float)) and not math.isnan(speedup) and not math.isinf(speedup):
489-
speedup_str = f"{speedup:.2f}"
490-
elif speedup == float("inf"):
491-
speedup_str = "∞"
492-
else:
493-
speedup_str = str(speedup)
494-
495-
if isinstance(perf_score, (int, float)) and not math.isnan(perf_score) and not math.isinf(perf_score):
496-
perf_str = f"{perf_score:.3f}"
497-
else:
498-
perf_str = str(perf_score)
547+
# Safe formatting for speedup and performance score
548+
speedup_str = safe_format_number(speedup, ".2f")
549+
perf_str = safe_format_number(perf_score, ".3f")
499550

500551
print(f" 📊 Speedup: {speedup_str}x, Score: {perf_str}")
501552

502-
stage2_score = total_weighted_score / total_weight if total_weight > 0 else 0.0
503-
553+
# Safe calculation of stage2_score to prevent division errors
554+
try:
555+
stage2_score = total_weighted_score / total_weight if total_weight > 0 else 0.0
556+
except (TypeError, ZeroDivisionError):
557+
stage2_score = 0.0
558+
504559
# Calculate overall score (Stage 1 gate + Stage 2 performance)
505560
overall_score = stage2_score # Since Stage 1 is just a gate
506561

507-
# Detailed performance analysis
508-
speedups = [r["benchmark"]["speedup"] for r in stage2_results
509-
if isinstance(r["benchmark"]["speedup"], (int, float)) and
510-
r["benchmark"]["speedup"] != float("inf") and
511-
not math.isnan(r["benchmark"]["speedup"])]
512-
avg_speedup = np.mean(speedups) if speedups else 0.0
513-
max_speedup = max(speedups) if speedups else 0.0
562+
# Detailed performance analysis with safe operations
563+
speedups = []
564+
for r in stage2_results:
565+
speedup_val = r["benchmark"]["speedup"]
566+
if (isinstance(speedup_val, (int, float)) and
567+
speedup_val != float("inf") and
568+
not math.isnan(speedup_val)):
569+
speedups.append(speedup_val)
570+
571+
try:
572+
avg_speedup = np.mean(speedups) if speedups else 0.0
573+
max_speedup = max(speedups) if speedups else 0.0
574+
except (TypeError, ValueError):
575+
avg_speedup = 0.0
576+
max_speedup = 0.0
514577

515578
print(f"\n📈 STAGE 2 Results:")
516579

517-
# Safe formatting
518-
stage2_str = f"{stage2_score:.3f}" if isinstance(stage2_score, (int, float)) else str(stage2_score)
519-
avg_speedup_str = f"{avg_speedup:.2f}" if isinstance(avg_speedup, (int, float)) else str(avg_speedup)
520-
max_speedup_str = f"{max_speedup:.2f}" if isinstance(max_speedup, (int, float)) else str(max_speedup)
521-
overall_str = f"{overall_score:.3f}" if isinstance(overall_score, (int, float)) else str(overall_score)
580+
# Safe formatting for final results
581+
stage2_str = safe_format_number(stage2_score, ".3f")
582+
avg_speedup_str = safe_format_number(avg_speedup, ".2f")
583+
max_speedup_str = safe_format_number(max_speedup, ".2f")
584+
overall_str = safe_format_number(overall_score, ".3f")
522585

523586
print(f" Performance Score: {stage2_str}")
524587
print(f" Average Speedup: {avg_speedup_str}x")
@@ -538,18 +601,32 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
538601
else:
539602
print(f" ❌ POOR: Need significant optimization")
540603

604+
# Ensure all return values are safe numeric types
605+
try:
606+
safe_stage1_pass_rate = float(stage1_pass_rate) if isinstance(stage1_pass_rate, (int, float)) else 0.0
607+
safe_stage2_score = float(stage2_score) if isinstance(stage2_score, (int, float)) else 0.0
608+
safe_overall_score = float(overall_score) if isinstance(overall_score, (int, float)) else 0.0
609+
safe_avg_speedup = float(avg_speedup) if isinstance(avg_speedup, (int, float)) else 0.0
610+
safe_max_speedup = float(max_speedup) if isinstance(max_speedup, (int, float)) else 0.0
611+
except (TypeError, ValueError):
612+
safe_stage1_pass_rate = 0.0
613+
safe_stage2_score = 0.0
614+
safe_overall_score = 0.0
615+
safe_avg_speedup = 0.0
616+
safe_max_speedup = 0.0
617+
541618
return {
542619
# Gate results
543620
"stage1_passed": stage1_passed,
544-
"stage1_pass_rate": stage1_pass_rate,
621+
"stage1_pass_rate": safe_stage1_pass_rate,
545622

546623
# Performance results
547-
"stage2_score": float(stage2_score),
548-
"overall_score": float(overall_score),
624+
"stage2_score": safe_stage2_score,
625+
"overall_score": safe_overall_score,
549626

550627
# Detailed metrics
551-
"avg_speedup": float(avg_speedup),
552-
"max_speedup": float(max_speedup),
628+
"avg_speedup": safe_avg_speedup,
629+
"max_speedup": safe_max_speedup,
553630
"num_stage1_tests": len(stage1_configs),
554631
"num_stage2_tests": len(stage2_configs),
555632
}
@@ -565,16 +642,31 @@ def evaluate_two_stage(program_path: str) -> Dict[str, float]:
565642
}
566643

567644

def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
    """
    Main evaluation function - Two-stage: Correctness gate + Performance.
    Includes comprehensive error handling to prevent formatting errors.
    """
    try:
        return evaluate_two_stage(program_path)
    except Exception as e:
        # Any failure (including string-formatting errors deep inside the
        # evaluator) degrades to a zero-score result instead of crashing.
        error_msg = str(e)
        print(f"❌ Evaluation failed with error: {error_msg}")
        fallback_metrics = {
            "stage1_passed": False,
            "stage2_score": 0.0,
            "overall_score": 0.0,
            "error": error_msg,
            "failed_at": "evaluation_error",
        }
        return fallback_metrics
573665

574666

575667
if __name__ == "__main__":
576668
# Test the two-stage evaluator
577-
print("Testing Two-Stage Evaluator...")
669+
print("Testing Robust Two-Stage Evaluator...")
578670
import os
579671

580672
initial_program_path = os.path.join(os.path.dirname(__file__), "initial_program.py")
@@ -584,7 +676,8 @@ def evaluate(program_path: str) -> Dict[str, float]:
584676
print("\nTwo-Stage Evaluation Results:")
585677
for k, v in results.items():
586678
if isinstance(v, (int, float)):
587-
print(f" {k}: {v:.4f}")
679+
formatted_v = safe_format_number(v, ".4f")
680+
print(f" {k}: {formatted_v}")
588681
else:
589682
print(f" {k}: {v}")
590683
else:

0 commit comments

Comments
 (0)