
Commit 356ece2

Update best_program.py
1 parent 475f8d0 commit 356ece2

1 file changed (+66, -41 lines)


examples/mlx_metal_kernel_opt/best_program.py

Lines changed: 66 additions & 41 deletions
@@ -24,22 +24,22 @@
 def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     """
     Custom Metal kernel implementation for Qwen3 GQA attention.
-
+
     Args:
-        queries: [B, num_heads=40, L, head_dim=128]
+        queries: [B, num_heads=40, L, head_dim=128]
         keys: [B, num_kv_heads=8, L, head_dim=128]
         values: [B, num_kv_heads=8, L, head_dim=128]
         scale: Attention scaling factor (1/sqrt(head_dim))
         mask: Attention mask (None, "causal", or boolean tensor)
-
+
     Returns:
         Attention output [B, num_heads=40, L, head_dim=128]
     """
-
+
     B, num_heads, L, head_dim = queries.shape
     _, num_kv_heads, _, _ = keys.shape
     heads_per_kv = num_heads // num_kv_heads  # Should be 5 for Qwen3
-
+
     # Handle mask conversion
     if mask == "causal" or mask is None:
         # Create causal mask for autoregressive attention
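Note on the GQA layout referenced in this hunk: with 40 query heads sharing 8 KV heads, each KV head serves heads_per_kv = 40 // 8 = 5 query heads. A minimal illustrative sketch of that mapping (variable names here are mine, not from the file):

# Illustrative only: how a query head maps to its shared KV head in a 40:8 GQA layout
num_heads, num_kv_heads = 40, 8
heads_per_kv = num_heads // num_kv_heads          # 5
kv_head_for = [q_head // heads_per_kv for q_head in range(num_heads)]
# query heads 0-4 -> KV head 0, 5-9 -> KV head 1, ..., 35-39 -> KV head 7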
@@ -56,18 +56,18 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     else:
         # Fallback for unsupported mask types
         return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
-
+
     # Expand mask to match batch and head dimensions if needed
     if mask_tensor.ndim == 2:
         mask_tensor = mx.broadcast_to(mask_tensor[None, None, :, :], (B, num_heads, L, L))
     elif mask_tensor.ndim == 3:
         mask_tensor = mx.broadcast_to(mask_tensor[:, None, :, :], (B, num_heads, L, L))
-
+
     # EVOLVE-BLOCK-START
-    # Custom Metal kernel source for Qwen3 GQA optimization
+    # Fixed Metal kernel source for Qwen3 GQA optimization
     # This kernel leverages the 40:8 head ratio and Apple Silicon architecture
     kernel_source = """
-    // Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern
+    // Fixed Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern
     // Thread mapping: each thread processes one query position
     uint thread_id = thread_position_in_grid.x;
     uint head_idx = thread_position_in_grid.y;
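The construction of mask_tensor for the causal case falls outside this hunk; as a rough sketch (assumed, not the file's exact code), a boolean causal mask in MLX can be built and broadcast like this:

import mlx.core as mx

L = 8
# True where query position i may attend to key position j (j <= i); assumed construction
causal = mx.arange(L)[:, None] >= mx.arange(L)[None, :]
mask_tensor = mx.broadcast_to(causal[None, None, :, :], (1, 40, L, L))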
@@ -102,10 +102,10 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 
     const uint out_base = q_base;
 
-    // Load query vector for this position using T4 chunks for coalesced access
-    thread T4 query_vec_chunks[HEAD_DIM / 4];
-    for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; d_chunk++) {
-        query_vec_chunks[d_chunk] = *(device T4*)(queries + q_base + d_chunk * 4);
+    // Load query vector for this position (using proper Metal syntax)
+    thread T query_vec[HEAD_DIM];
+    for (uint d = 0; d < HEAD_DIM; d++) {
+        query_vec[d] = queries[q_base + d];
     }
 
     // Fused attention pass using online softmax for memory efficiency.
@@ -114,9 +114,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     T denominator = T(0.0);
 
     // Accumulator for the output vector, held in fast thread memory.
-    thread T4 output_accumulator[HEAD_DIM / 4];
-    for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-        output_accumulator[d_chunk] = T4(0.0);
+    thread T output_accumulator[HEAD_DIM];
+    for (uint d = 0; d < HEAD_DIM; ++d) {
+        output_accumulator[d] = T(0.0);
     }
 
     // Single pass over all key/value positions, reducing global memory traffic.
@@ -127,11 +127,25 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
             continue;
         }
 
-        // Compute Q @ K^T for this key position
+        // Compute Q @ K^T for this key position using vectorized operations
         const uint k_base = k_base_start + key_pos * HEAD_DIM;
         T score = T(0.0);
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            score += dot(query_vec_chunks[d_chunk], *(device T4*)(keys + k_base + d_chunk * 4));
+
+        // Process 4 elements at a time for SIMD efficiency
+        for (uint d = 0; d < HEAD_DIM; d += 4) {
+            if (d + 3 < HEAD_DIM) {
+                // Manual vectorization for better performance
+                score += query_vec[d] * keys[k_base + d] +
+                         query_vec[d+1] * keys[k_base + d+1] +
+                         query_vec[d+2] * keys[k_base + d+2] +
+                         query_vec[d+3] * keys[k_base + d+3];
+            } else {
+                // Handle remaining elements
+                for (uint dd = d; dd < HEAD_DIM; ++dd) {
+                    score += query_vec[dd] * keys[k_base + dd];
+                }
+                break;
+            }
         }
         score *= scale_val;
 
@@ -146,10 +160,22 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 
         // Load the value vector and update the output accumulator.
         const uint v_base = v_base_start + key_pos * HEAD_DIM;
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            T4 v_chunk = *(device T4*)(values + v_base + d_chunk * 4);
-            // Rescale the existing accumulator and add the new weighted value.
-            output_accumulator[d_chunk] = output_accumulator[d_chunk] * exp_old_max_diff + exp_new_val_diff * v_chunk;
+
+        // Process values with manual vectorization
+        for (uint d = 0; d < HEAD_DIM; d += 4) {
+            if (d + 3 < HEAD_DIM) {
+                // Rescale the existing accumulator and add the new weighted value.
+                output_accumulator[d] = output_accumulator[d] * exp_old_max_diff + exp_new_val_diff * values[v_base + d];
+                output_accumulator[d+1] = output_accumulator[d+1] * exp_old_max_diff + exp_new_val_diff * values[v_base + d+1];
+                output_accumulator[d+2] = output_accumulator[d+2] * exp_old_max_diff + exp_new_val_diff * values[v_base + d+2];
+                output_accumulator[d+3] = output_accumulator[d+3] * exp_old_max_diff + exp_new_val_diff * values[v_base + d+3];
+            } else {
+                // Handle remaining elements
+                for (uint dd = d; dd < HEAD_DIM; ++dd) {
+                    output_accumulator[dd] = output_accumulator[dd] * exp_old_max_diff + exp_new_val_diff * values[v_base + dd];
+                }
+                break;
+            }
         }
 
         max_score = new_max_score;
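The rescale-and-accumulate step above is the standard online-softmax recurrence. A small reference sketch of the same update for a single query position (pure Python/NumPy, written only to illustrate the math the kernel implements; names are mine):

import numpy as np

def online_softmax_attention(q, K, V, scale):
    # q: [D], K and V: [L, D] -- one query position, as handled by one kernel thread
    max_score, denom = -np.inf, 0.0
    out = np.zeros_like(q)
    for k_vec, v_vec in zip(K, V):
        score = scale * float(q @ k_vec)
        new_max = max(max_score, score)
        exp_old = np.exp(max_score - new_max)   # rescales the running accumulator
        exp_new = np.exp(score - new_max)
        denom = denom * exp_old + exp_new
        out = out * exp_old + exp_new * v_vec
        max_score = new_max
    return out / denom

# Agrees with softmax(scale * q @ K.T) @ V up to floating-point tolerance.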
@@ -158,34 +184,34 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 
     // Final normalization and write to global memory once at the end.
     if (denominator > T(1e-9)) { // Use a small epsilon for stability
         T inv_denominator = T(1.0) / denominator;
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            *(device T4*)(output + out_base + d_chunk * 4) = output_accumulator[d_chunk] * inv_denominator;
+        for (uint d = 0; d < HEAD_DIM; ++d) {
+            output[out_base + d] = output_accumulator[d] * inv_denominator;
         }
     } else {
         // Handle cases where all scores were masked out; write zeros.
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            *(device T4*)(output + out_base + d_chunk * 4) = T4(0.0);
+        for (uint d = 0; d < HEAD_DIM; ++d) {
+            output[out_base + d] = T(0.0);
         }
     }
     """
     # EVOLVE-BLOCK-END
-
+
     try:
         # Prepare kernel inputs
         scale_tensor = mx.array([scale], dtype=queries.dtype)
         use_mask_tensor = mx.array([1 if use_mask else 0], dtype=mx.int32)
-
+
         # Create and execute custom Metal kernel
         kernel = mx.fast.metal_kernel(
             name="qwen3_gqa_attention_kernel",
             input_names=["queries", "keys", "values", "mask", "scale", "use_mask"],
             output_names=["output"],
             source=kernel_source,
         )
-
+
         # Optimize thread group size for Apple Silicon
         threadgroup_size = min(32, L)  # Adapt to sequence length
-
+
         # Execute kernel
         outputs = kernel(
             inputs=[queries, keys, values, mask_tensor, scale_tensor, use_mask_tensor],
@@ -203,9 +229,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
                 ("HEADS_PER_KV", heads_per_kv),
             ],
         )
-
+
         return outputs[0]
-
+
     except Exception as e:
         # Fallback to standard MLX implementation if custom kernel fails
         print(f"⚠️ Custom GQA kernel failed: {e}, falling back to MLX SPDA")
@@ -215,7 +241,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 class CustomGQAAttention(nn.Module):
     """
     Qwen3 attention module with custom Metal kernel optimization.
-
+
     This module integrates the custom Metal kernel while maintaining
     compatibility with the standard MLX-LM interface.
     """
@@ -244,7 +270,6 @@ def __init__(self, args):
         # Standard MLX-LM RoPE
         try:
             from mlx_lm.models.rope_utils import initialize_rope
-
             self.rope = initialize_rope(
                 head_dim,
                 base=args.rope_theta,
@@ -255,7 +280,7 @@ def __init__(self, args):
         except ImportError:
             print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE")
             self.rope = None
-
+
         print(f"🔧 Initialized Custom Metal GQA Attention")
         print(f"   📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads//n_kv_heads}:1 ratio)")
         print(f"   🎯 Head dimension: {head_dim}")
@@ -424,11 +449,11 @@ class MockArgs:
         output = metal_attn(x, mask=mask)
 
         print(f"✅ Metal GQA output shape: {output.shape}")
-
+
         # Check for valid output
         has_nan = bool(mx.any(mx.isnan(output)))
         has_inf = bool(mx.any(mx.isinf(output)))
-
+
         print(f"✅ Has NaN: {has_nan}, Has Inf: {has_inf}")
 
         # Check output statistics
@@ -444,10 +469,10 @@ class MockArgs:
         k = mx.random.normal((B, 8, L, D))  # 8 KV heads
         v = mx.random.normal((B, 8, L, D))
         scale = 1.0 / math.sqrt(D)
-
+
         kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal")
         print(f"✅ Direct kernel output shape: {kernel_output.shape}")
-
+
         kernel_mean = float(mx.mean(kernel_output))
         kernel_std = float(mx.std(kernel_output))
         print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}")
@@ -471,7 +496,7 @@ class MockArgs:
     print("Ready for Metal Kernel Evolution")
     print("Evolution focus:")
     print("1. 🔧 Metal kernel source code optimization")
-    print("2. 💾 Memory access pattern improvements for Apple Silicon")
+    print("2. 💾 Memory access pattern improvements for Apple Silicon")
     print("3. 🎯 GQA-specific optimizations for 40:8 head ratio")
     print("4. ⚡ Vectorization and SIMD optimization")
     print("5. 🚀 Thread group and grid configuration tuning")
