
Commit 475f8d0

1 parent 9778bb1 commit 475f8d0

6 files changed: +367 -239 lines changed

examples/circle_packing_with_artifacts/evaluator.py

Lines changed: 6 additions & 6 deletions

@@ -295,9 +295,9 @@ def evaluate(program_path):
         # Add successful packing stats for good solutions
         if valid and target_ratio > 0.95:  # Near-optimal solutions
             artifacts["stdout"] = f"Excellent packing! Achieved {target_ratio:.1%} of target value"
-            artifacts["radius_stats"] = (
-                f"Min: {validation_details['min_radius']:.6f}, Max: {validation_details['max_radius']:.6f}, Avg: {validation_details['avg_radius']:.6f}"
-            )
+            artifacts[
+                "radius_stats"
+            ] = f"Min: {validation_details['min_radius']:.6f}, Max: {validation_details['max_radius']:.6f}, Avg: {validation_details['avg_radius']:.6f}"
 
         return EvaluationResult(
             metrics={
@@ -404,9 +404,9 @@ def evaluate_stage1(program_path):
 
         # Add validation issues if any
         if not valid:
-            artifacts["stderr"] = (
-                f"Validation failed: {len(validation_details.get('boundary_violations', []))} boundary violations, {len(validation_details.get('overlaps', []))} overlaps"
-            )
+            artifacts[
+                "stderr"
+            ] = f"Validation failed: {len(validation_details.get('boundary_violations', []))} boundary violations, {len(validation_details.get('overlaps', []))} overlaps"
             artifacts["failure_stage"] = "stage1_geometric_validation"
             if validation_details.get("boundary_violations"):
                 artifacts["boundary_issues"] = validation_details["boundary_violations"][

examples/mlx_metal_kernel_opt/best_program.py

Lines changed: 21 additions & 20 deletions

@@ -24,22 +24,22 @@
 def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     """
     Custom Metal kernel implementation for Qwen3 GQA attention.
-
+
     Args:
-        queries: [B, num_heads=40, L, head_dim=128]
+        queries: [B, num_heads=40, L, head_dim=128]
         keys: [B, num_kv_heads=8, L, head_dim=128]
         values: [B, num_kv_heads=8, L, head_dim=128]
         scale: Attention scaling factor (1/sqrt(head_dim))
         mask: Attention mask (None, "causal", or boolean tensor)
-
+
     Returns:
         Attention output [B, num_heads=40, L, head_dim=128]
     """
-
+
     B, num_heads, L, head_dim = queries.shape
     _, num_kv_heads, _, _ = keys.shape
     heads_per_kv = num_heads // num_kv_heads  # Should be 5 for Qwen3
-
+
     # Handle mask conversion
     if mask == "causal" or mask is None:
         # Create causal mask for autoregressive attention
@@ -56,13 +56,13 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     else:
         # Fallback for unsupported mask types
         return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
-
+
     # Expand mask to match batch and head dimensions if needed
     if mask_tensor.ndim == 2:
         mask_tensor = mx.broadcast_to(mask_tensor[None, None, :, :], (B, num_heads, L, L))
     elif mask_tensor.ndim == 3:
         mask_tensor = mx.broadcast_to(mask_tensor[:, None, :, :], (B, num_heads, L, L))
-
+
     # EVOLVE-BLOCK-START
     # Custom Metal kernel source for Qwen3 GQA optimization
     # This kernel leverages the 40:8 head ratio and Apple Silicon architecture
@@ -169,23 +169,23 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     }
     """
     # EVOLVE-BLOCK-END
-
+
     try:
         # Prepare kernel inputs
         scale_tensor = mx.array([scale], dtype=queries.dtype)
         use_mask_tensor = mx.array([1 if use_mask else 0], dtype=mx.int32)
-
+
         # Create and execute custom Metal kernel
         kernel = mx.fast.metal_kernel(
             name="qwen3_gqa_attention_kernel",
             input_names=["queries", "keys", "values", "mask", "scale", "use_mask"],
             output_names=["output"],
             source=kernel_source,
         )
-
+
         # Optimize thread group size for Apple Silicon
         threadgroup_size = min(32, L)  # Adapt to sequence length
-
+
         # Execute kernel
         outputs = kernel(
             inputs=[queries, keys, values, mask_tensor, scale_tensor, use_mask_tensor],
@@ -203,9 +203,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
                 ("HEADS_PER_KV", heads_per_kv),
             ],
         )
-
+
         return outputs[0]
-
+
     except Exception as e:
         # Fallback to standard MLX implementation if custom kernel fails
         print(f"⚠️ Custom GQA kernel failed: {e}, falling back to MLX SPDA")
@@ -215,7 +215,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 class CustomGQAAttention(nn.Module):
     """
     Qwen3 attention module with custom Metal kernel optimization.
-
+
     This module integrates the custom Metal kernel while maintaining
     compatibility with the standard MLX-LM interface.
     """
@@ -244,6 +244,7 @@ def __init__(self, args):
         # Standard MLX-LM RoPE
         try:
             from mlx_lm.models.rope_utils import initialize_rope
+
             self.rope = initialize_rope(
                 head_dim,
                 base=args.rope_theta,
@@ -254,7 +255,7 @@ def __init__(self, args):
         except ImportError:
             print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE")
             self.rope = None
-
+
         print(f"🔧 Initialized Custom Metal GQA Attention")
         print(f"   📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads//n_kv_heads}:1 ratio)")
         print(f"   🎯 Head dimension: {head_dim}")
@@ -423,11 +424,11 @@ class MockArgs:
     output = metal_attn(x, mask=mask)
 
     print(f"✅ Metal GQA output shape: {output.shape}")
-
+
     # Check for valid output
     has_nan = bool(mx.any(mx.isnan(output)))
     has_inf = bool(mx.any(mx.isinf(output)))
-
+
     print(f"✅ Has NaN: {has_nan}, Has Inf: {has_inf}")
 
     # Check output statistics
@@ -443,10 +444,10 @@ class MockArgs:
     k = mx.random.normal((B, 8, L, D))  # 8 KV heads
     v = mx.random.normal((B, 8, L, D))
     scale = 1.0 / math.sqrt(D)
-
+
     kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal")
     print(f"✅ Direct kernel output shape: {kernel_output.shape}")
-
+
     kernel_mean = float(mx.mean(kernel_output))
     kernel_std = float(mx.std(kernel_output))
     print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}")
@@ -470,7 +471,7 @@ class MockArgs:
     print("Ready for Metal Kernel Evolution")
     print("Evolution focus:")
     print("1. 🔧 Metal kernel source code optimization")
-    print("2. 💾 Memory access pattern improvements for Apple Silicon")
+    print("2. 💾 Memory access pattern improvements for Apple Silicon")
     print("3. 🎯 GQA-specific optimizations for 40:8 head ratio")
     print("4. ⚡ Vectorization and SIMD optimization")
     print("5. 🚀 Thread group and grid configuration tuning")
