Skip to content

Commit c1e9a02

Browse files
committed
s
1 parent 82c2796 commit c1e9a02

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

examples/mlx_metal_kernel_opt/best_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
55 55   use_mask = True
56 56   else:
57 57   # Raise error for unsupported mask types - no fallback
58    - raise ValueError(f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask.")
   58 + raise ValueError(
   59 + f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask."
   60 + )
59 61
60 62   # Expand mask to match batch and head dimensions if needed
61 63   if mask_tensor.ndim == 2:

examples/mlx_metal_kernel_opt/initial_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
55 55   use_mask = True
56 56   else:
57 57   # Raise error for unsupported mask types - no fallback
58    - raise ValueError(f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask.")
   58 + raise ValueError(
   59 + f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask."
   60 + )
59 61
60 62   # Expand mask to match batch and head dimensions if needed
61 63   if mask_tensor.ndim == 2:

examples/mlx_metal_kernel_opt/run_benchmarks.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,20 @@ def run_optimized_benchmark(args, original_dir):
106 106   try:
107 107   # Import the optimized attention implementation
108 108   # First, try the OpenEvolve output directory (most likely location)
109     - best_program_path = os.path.join(original_dir, "openevolve_output", "best", "best_program.py")
110     -
    109 + best_program_path = os.path.join(
    110 + original_dir, "openevolve_output", "best", "best_program.py"
    111 + )
    112 +
111 113   # Fallback to root directory if not found in openevolve_output
112 114   if not os.path.exists(best_program_path):
113 115   best_program_path = os.path.join(original_dir, "best_program.py")
114     -
    116 +
115 117   if not os.path.exists(best_program_path):
116 118   print(f"❌ Error: Optimized program not found")
117 119   print("Searched in the following locations:")
118     - print(f" 1. {os.path.join(original_dir, 'openevolve_output', 'best', 'best_program.py')}")
    120 + print(
    121 + f" 1. {os.path.join(original_dir, 'openevolve_output', 'best', 'best_program.py')}"
    122 + )
119 123   print(f" 2. {os.path.join(original_dir, 'best_program.py')}")
120 124   print("Please ensure OpenEvolve has generated an optimized solution")
121 125   print("Expected path: ./openevolve_output/best/best_program.py")
@@ -454,7 +458,9 @@ def print_comparison_summary(comparison_results):
454 458
455 459   print(f"\n📊 ABSOLUTE PERFORMANCE:")
456 460   print(f" 🔵 Standard MLX-LM: {summary['avg_standard_decode_speed']:.1f} tokens/sec average")
457     - print(f" 🟠 Metal Kernel Optimized: {summary['avg_optimized_decode_speed']:.1f} tokens/sec average")
    461 + print(
    462 + f" 🟠 Metal Kernel Optimized: {summary['avg_optimized_decode_speed']:.1f} tokens/sec average"
    463 + )
458 464   print(
459 465   f" 📈 Net Improvement: {summary['avg_optimized_decode_speed'] - summary['avg_standard_decode_speed']:+.1f} tokens/sec"
460 466   )

0 commit comments

Comments
 (0)