def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
    """
    Custom Metal kernel implementation for Qwen3 GQA attention.

    Args:
        queries: [B, num_heads=40, L, head_dim=128]
        keys: [B, num_kv_heads=8, L, head_dim=128]
        values: [B, num_kv_heads=8, L, head_dim=128]
        scale: Attention scaling factor (1/sqrt(head_dim))
        mask: Attention mask (None, "causal", or boolean tensor)

    Returns:
        Attention output [B, num_heads=40, L, head_dim=128]
    """

    B, num_heads, L, head_dim = queries.shape
    _, num_kv_heads, _, _ = keys.shape
    heads_per_kv = num_heads // num_kv_heads  # Should be 5 for Qwen3
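    # Illustrative note (added): in GQA, query head h reads KV head h // heads_per_kv,
    # so with 40 query heads and 8 KV heads, query heads 0-4 share KV head 0,
    # heads 5-9 share KV head 1, and so on.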

    # Handle mask conversion
    if mask is None or (isinstance(mask, str) and mask == "causal"):
        # Create causal mask for autoregressive attention
    # ... (causal-mask construction and any intermediate mask branches elided in this excerpt) ...
    else:
        # Fallback for unsupported mask types
        return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)

    # Expand mask to match batch and head dimensions if needed
    if mask_tensor.ndim == 2:
        mask_tensor = mx.broadcast_to(mask_tensor[None, None, :, :], (B, num_heads, L, L))
    elif mask_tensor.ndim == 3:
        mask_tensor = mx.broadcast_to(mask_tensor[:, None, :, :], (B, num_heads, L, L))
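    # Note (added): after broadcasting, mask_tensor has shape [B, num_heads, L, L],
    # i.e. one full L x L score mask per query head, matching the `mask` input passed to the kernel.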

    # EVOLVE-BLOCK-START
    # Custom Metal kernel source for Qwen3 GQA optimization
    # This kernel leverages the 40:8 head ratio and Apple Silicon architecture
    kernel_source = """
        // ... (Metal shader body elided in this excerpt) ...
    }
    """
    # EVOLVE-BLOCK-END

    try:
        # Prepare kernel inputs
        scale_tensor = mx.array([scale], dtype=queries.dtype)
        use_mask_tensor = mx.array([1 if use_mask else 0], dtype=mx.int32)  # use_mask is set in the elided mask-handling code above

        # Create and execute custom Metal kernel
        kernel = mx.fast.metal_kernel(
            name="qwen3_gqa_attention_kernel",
            input_names=["queries", "keys", "values", "mask", "scale", "use_mask"],
            output_names=["output"],
            source=kernel_source,
        )
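        # Note (added): mx.fast.metal_kernel takes only the body of the Metal function;
        # MLX generates the full kernel signature from input_names/output_names and
        # compiles the kernel lazily the first time it is called.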

        # Optimize thread group size for Apple Silicon
        threadgroup_size = min(32, L)  # Adapt to sequence length
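        # Note (added): 32 matches the SIMD-group execution width on Apple GPUs, so shorter
        # sequences simply launch a partially filled SIMD group.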

        # Execute kernel
        outputs = kernel(
            inputs=[queries, keys, values, mask_tensor, scale_tensor, use_mask_tensor],
            # ... (grid, threadgroup, and output shape/dtype arguments elided in this excerpt) ...
            template=[
                # ... (preceding template constants elided) ...
203203 ("HEADS_PER_KV" , heads_per_kv ),
204204 ],
205205 )
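        # Note (added): the kernel call returns one array per entry in output_names,
        # so outputs[0] is the attention result with the same shape as `queries`.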

        return outputs[0]

    except Exception as e:
        # Fallback to standard MLX implementation if custom kernel fails
        print(f"⚠️ Custom GQA kernel failed: {e}, falling back to MLX SDPA")
        # ... (fallback call to mx.fast.scaled_dot_product_attention elided in this excerpt) ...


class CustomGQAAttention(nn.Module):
    """
    Qwen3 attention module with custom Metal kernel optimization.

    This module integrates the custom Metal kernel while maintaining
    compatibility with the standard MLX-LM interface.
    """

    def __init__(self, args):
        # ... (setup of n_heads, n_kv_heads, head_dim, and related module state elided in this excerpt) ...
        # Standard MLX-LM RoPE
        try:
            from mlx_lm.models.rope_utils import initialize_rope

            self.rope = initialize_rope(
                head_dim,
                base=args.rope_theta,
                # ... (remaining initialize_rope arguments elided in this excerpt) ...
            )
        except ImportError:
            print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE")
            self.rope = None

        print("🔧 Initialized Custom Metal GQA Attention")
        print(f"   📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads // n_kv_heads}:1 ratio)")
        print(f"   🎯 Head dimension: {head_dim}")

    # ... (CustomGQAAttention.__call__ and the test-harness setup, including the MockArgs helper
    # and the creation of metal_attn, x, and mask, are elided; the lines below are from the test code) ...
    output = metal_attn(x, mask=mask)

    print(f"✅ Metal GQA output shape: {output.shape}")

    # Check for valid output
    has_nan = bool(mx.any(mx.isnan(output)))
    has_inf = bool(mx.any(mx.isinf(output)))

    print(f"✅ Has NaN: {has_nan}, Has Inf: {has_inf}")

    # Check output statistics
    # ... (output statistics and the creation of the q test tensor and the B, L, D dimensions elided in this excerpt) ...
    k = mx.random.normal((B, 8, L, D))  # 8 KV heads
    v = mx.random.normal((B, 8, L, D))
    scale = 1.0 / math.sqrt(D)

    kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal")
    print(f"✅ Direct kernel output shape: {kernel_output.shape}")

    kernel_mean = float(mx.mean(kernel_output))
    kernel_std = float(mx.std(kernel_output))
    print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}")

    # ... (remaining test output elided in this excerpt) ...

    print("Ready for Metal Kernel Evolution")
    print("Evolution focus:")
    print("1. 🔧 Metal kernel source code optimization")
    print("2. 💾 Memory access pattern improvements for Apple Silicon")
    print("3. 🎯 GQA-specific optimizations for 40:8 head ratio")
    print("4. ⚡ Vectorization and SIMD optimization")
    print("5. 🚀 Thread group and grid configuration tuning")
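

# Hedged validation sketch (added; not part of the original file): one simple way to
# sanity-check the custom kernel is to compare its output against
# mx.fast.scaled_dot_product_attention, the same routine the function above already
# uses as its fallback. The helper name and the 1e-3 tolerance are assumptions.
def _check_against_mlx_sdpa(B=1, L=64, n_heads=40, n_kv_heads=8, head_dim=128):
    q = mx.random.normal((B, n_heads, L, head_dim))
    k = mx.random.normal((B, n_kv_heads, L, head_dim))
    v = mx.random.normal((B, n_kv_heads, L, head_dim))
    scale = 1.0 / math.sqrt(head_dim)

    custom = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal")
    reference = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

    max_err = float(mx.max(mx.abs(custom - reference)))
    print(f"Max |custom - MLX SDPA| difference: {max_err:.6e}")
    return max_err < 1e-3  # assumed tolerance for floating-point differences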