Skip to content

Commit 4f4a500

Browse files
committed
Update evaluator.py
1 parent 9b90c69 commit 4f4a500

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

examples/mlx_metal_kernel_opt/evaluator.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -517,14 +517,14 @@ def _calculate_performance_metrics(self, results: List[BenchmarkResult]) -> Dict
517517
memories = [r.peak_memory_gb for r in results if r.peak_memory_gb > 0]
518518

519519
return {
520-
'avg_decode_speed': np.mean(decode_speeds) if decode_speeds else 0,
521-
'min_decode_speed': np.min(decode_speeds) if decode_speeds else 0,
522-
'max_decode_speed': np.max(decode_speeds) if decode_speeds else 0,
523-
'avg_prefill_speed': np.mean(prefill_speeds) if prefill_speeds else 0,
524-
'avg_memory_gb': np.mean(memories) if memories else 0,
525-
'max_memory_gb': np.max(memories) if memories else 0,
526-
'num_successful_tests': len(results),
527-
'decode_speed_std': np.std(decode_speeds) if len(decode_speeds) > 1 else 0
520+
'avg_decode_speed': float(np.mean(decode_speeds)) if decode_speeds else 0.0,
521+
'min_decode_speed': float(np.min(decode_speeds)) if decode_speeds else 0.0,
522+
'max_decode_speed': float(np.max(decode_speeds)) if decode_speeds else 0.0,
523+
'avg_prefill_speed': float(np.mean(prefill_speeds)) if prefill_speeds else 0.0,
524+
'avg_memory_gb': float(np.mean(memories)) if memories else 0.0,
525+
'max_memory_gb': float(np.max(memories)) if memories else 0.0,
526+
'num_successful_tests': int(len(results)),
527+
'decode_speed_std': float(np.std(decode_speeds)) if len(decode_speeds) > 1 else 0.0
528528
}
529529

530530
def _calculate_final_score(self, performance: Dict[str, float], correctness: float) -> float:
@@ -591,10 +591,10 @@ def _compare_to_baseline(self, performance: Dict[str, float]) -> Dict[str, float
591591
current_decode = performance['avg_decode_speed']
592592

593593
return {
594-
'decode_improvement_pct': ((current_decode - baseline_decode) / baseline_decode) * 100,
595-
'decode_improvement_absolute': current_decode - baseline_decode,
596-
'memory_change_gb': performance['avg_memory_gb'] - self.baseline_metrics['avg_memory_gb'],
597-
'target_achieved': current_decode >= 80.0, # 80+ tokens/sec target
594+
'decode_improvement_pct': float(((current_decode - baseline_decode) / baseline_decode) * 100),
595+
'decode_improvement_absolute': float(current_decode - baseline_decode),
596+
'memory_change_gb': float(performance['avg_memory_gb'] - self.baseline_metrics['avg_memory_gb']),
597+
'target_achieved': bool(current_decode >= 80.0), # 80+ tokens/sec target
598598
}
599599

600600
def _generate_summary(self, performance: Dict[str, float], correctness: float) -> str:
@@ -656,12 +656,12 @@ def _create_failure_result(self, error_message: str) -> Dict[str, Any]:
656656
def _result_to_dict(self, result: BenchmarkResult) -> Dict:
657657
"""Convert BenchmarkResult to dictionary"""
658658
return {
659-
'name': result.name,
660-
'decode_tokens_per_sec': result.decode_tokens_per_sec,
661-
'prefill_tokens_per_sec': result.prefill_tokens_per_sec,
662-
'peak_memory_gb': result.peak_memory_gb,
663-
'generated_tokens': result.generated_tokens,
664-
'total_time_sec': result.total_time_sec
659+
'name': str(result.name),
660+
'decode_tokens_per_sec': float(result.decode_tokens_per_sec),
661+
'prefill_tokens_per_sec': float(result.prefill_tokens_per_sec),
662+
'peak_memory_gb': float(result.peak_memory_gb),
663+
'generated_tokens': int(result.generated_tokens),
664+
'total_time_sec': float(result.total_time_sec)
665665
}
666666

667667

0 commit comments

Comments (0)