 def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     """
     Custom Metal kernel implementation for Qwen3 GQA attention.
-
+
     Args:
-        queries: [B, num_heads=40, L, head_dim=128]
+        queries: [B, num_heads=40, L, head_dim=128]
         keys: [B, num_kv_heads=8, L, head_dim=128]
         values: [B, num_kv_heads=8, L, head_dim=128]
         scale: Attention scaling factor (1/sqrt(head_dim))
         mask: Attention mask (None, "causal", or boolean tensor)
-
+
     Returns:
         Attention output [B, num_heads=40, L, head_dim=128]
     """
-
+
     B, num_heads, L, head_dim = queries.shape
     _, num_kv_heads, _, _ = keys.shape
     heads_per_kv = num_heads // num_kv_heads  # Should be 5 for Qwen3
-
+
     # Handle mask conversion
     if mask == "causal" or mask is None:
         # Create causal mask for autoregressive attention
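
For orientation, the computation this kernel has to reproduce can be written in a few lines of plain MLX ops. The sketch below is illustrative only (the helper name `gqa_reference` is not part of this diff): it repeats the 8 KV heads to match the 40 query heads and applies a standard masked softmax attention.

import mlx.core as mx

def gqa_reference(q, k, v, scale, mask):
    # q: [B, 40, L, D]; k, v: [B, 8, L, D]; mask: boolean [L, L] (True = keep)
    heads_per_kv = q.shape[1] // k.shape[1]          # 5 for Qwen3
    k = mx.repeat(k, heads_per_kv, axis=1)           # expand KV heads to match query heads
    v = mx.repeat(v, heads_per_kv, axis=1)
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)   # [B, 40, L, L]
    scores = mx.where(mask, scores, mx.array(float("-inf"), dtype=scores.dtype))
    return mx.softmax(scores, axis=-1) @ v           # [B, 40, L, D]
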
@@ -56,18 +56,18 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     else:
         # Fallback for unsupported mask types
         return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
-
+
     # Expand mask to match batch and head dimensions if needed
     if mask_tensor.ndim == 2:
         mask_tensor = mx.broadcast_to(mask_tensor[None, None, :, :], (B, num_heads, L, L))
     elif mask_tensor.ndim == 3:
         mask_tensor = mx.broadcast_to(mask_tensor[:, None, :, :], (B, num_heads, L, L))
-
+
     # EVOLVE-BLOCK-START
-    # Custom Metal kernel source for Qwen3 GQA optimization
+    # Fixed Metal kernel source for Qwen3 GQA optimization
     # This kernel leverages the 40:8 head ratio and Apple Silicon architecture
     kernel_source = """
-    // Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern
+    // Fixed Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern
     // Thread mapping: each thread processes one query position
     uint thread_id = thread_position_in_grid.x;
     uint head_idx = thread_position_in_grid.y;
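
The 40:8 pattern means each group of five consecutive query heads reads the same KV head, so the kernel can derive its KV head index from the query head index by integer division. A small Python illustration of that mapping (the helper name is hypothetical, for illustration only):

def kv_head_for(query_head: int, heads_per_kv: int = 5) -> int:
    # Query heads 0-4 share KV head 0, heads 5-9 share KV head 1, and so on.
    return query_head // heads_per_kv

assert [kv_head_for(h) for h in (0, 4, 5, 39)] == [0, 0, 1, 7]
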
@@ -102,10 +102,10 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 
     const uint out_base = q_base;
 
-    // Load query vector for this position using T4 chunks for coalesced access
-    thread T4 query_vec_chunks[HEAD_DIM / 4];
-    for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; d_chunk++) {
-        query_vec_chunks[d_chunk] = *(device T4*)(queries + q_base + d_chunk * 4);
+    // Load query vector for this position (using proper Metal syntax)
+    thread T query_vec[HEAD_DIM];
+    for (uint d = 0; d < HEAD_DIM; d++) {
+        query_vec[d] = queries[q_base + d];
     }
 
     // Fused attention pass using online softmax for memory efficiency.
@@ -114,9 +114,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     T denominator = T(0.0);
 
     // Accumulator for the output vector, held in fast thread memory.
-    thread T4 output_accumulator[HEAD_DIM / 4];
-    for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-        output_accumulator[d_chunk] = T4(0.0);
+    thread T output_accumulator[HEAD_DIM];
+    for (uint d = 0; d < HEAD_DIM; ++d) {
+        output_accumulator[d] = T(0.0);
     }
 
     // Single pass over all key/value positions, reducing global memory traffic.
@@ -127,11 +127,25 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
             continue;
         }
 
-        // Compute Q @ K^T for this key position
+        // Compute Q @ K^T for this key position using vectorized operations
         const uint k_base = k_base_start + key_pos * HEAD_DIM;
         T score = T(0.0);
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            score += dot(query_vec_chunks[d_chunk], *(device T4*)(keys + k_base + d_chunk * 4));
+
+        // Process 4 elements at a time for SIMD efficiency
+        for (uint d = 0; d < HEAD_DIM; d += 4) {
+            if (d + 3 < HEAD_DIM) {
+                // Manual vectorization for better performance
+                score += query_vec[d] * keys[k_base + d] +
+                         query_vec[d+1] * keys[k_base + d+1] +
+                         query_vec[d+2] * keys[k_base + d+2] +
+                         query_vec[d+3] * keys[k_base + d+3];
+            } else {
+                // Handle remaining elements
+                for (uint dd = d; dd < HEAD_DIM; ++dd) {
+                    score += query_vec[dd] * keys[k_base + dd];
+                }
+                break;
+            }
         }
         score *= scale_val;
 
@@ -146,10 +160,22 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 
         // Load the value vector and update the output accumulator.
         const uint v_base = v_base_start + key_pos * HEAD_DIM;
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            T4 v_chunk = *(device T4*)(values + v_base + d_chunk * 4);
-            // Rescale the existing accumulator and add the new weighted value.
-            output_accumulator[d_chunk] = output_accumulator[d_chunk] * exp_old_max_diff + exp_new_val_diff * v_chunk;
+
+        // Process values with manual vectorization
+        for (uint d = 0; d < HEAD_DIM; d += 4) {
+            if (d + 3 < HEAD_DIM) {
+                // Rescale the existing accumulator and add the new weighted value.
+                output_accumulator[d] = output_accumulator[d] * exp_old_max_diff + exp_new_val_diff * values[v_base + d];
+                output_accumulator[d+1] = output_accumulator[d+1] * exp_old_max_diff + exp_new_val_diff * values[v_base + d+1];
+                output_accumulator[d+2] = output_accumulator[d+2] * exp_old_max_diff + exp_new_val_diff * values[v_base + d+2];
+                output_accumulator[d+3] = output_accumulator[d+3] * exp_old_max_diff + exp_new_val_diff * values[v_base + d+3];
+            } else {
+                // Handle remaining elements
+                for (uint dd = d; dd < HEAD_DIM; ++dd) {
+                    output_accumulator[dd] = output_accumulator[dd] * exp_old_max_diff + exp_new_val_diff * values[v_base + dd];
+                }
+                break;
+            }
         }
 
         max_score = new_max_score;
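
The rescaling above is the standard online-softmax recurrence (as used in FlashAttention-style kernels): when a new key raises the running maximum, the denominator and the partial output are both scaled by exp(old_max - new_max) before the new weighted value is added, so a single pass over the keys suffices. A plain-Python sketch of the same recurrence for one query row (function name and list-based math are purely illustrative):

import math

def online_softmax_attention(q_row, keys, values, scale):
    """Streaming softmax over one query row; mirrors the kernel's update rule."""
    max_score, denom = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for k_vec, v_vec in zip(keys, values):
        score = scale * sum(qi * ki for qi, ki in zip(q_row, k_vec))
        new_max = max(max_score, score)
        rescale = math.exp(max_score - new_max)   # exp(old_max - new_max)
        weight = math.exp(score - new_max)        # softmax numerator for this key
        denom = denom * rescale + weight
        acc = [a * rescale + weight * vi for a, vi in zip(acc, v_vec)]
        max_score = new_max
    return [a / denom for a in acc]
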
@@ -158,34 +184,34 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     // Final normalization and write to global memory once at the end.
     if (denominator > T(1e-9)) { // Use a small epsilon for stability
         T inv_denominator = T(1.0) / denominator;
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            *(device T4*)(output + out_base + d_chunk * 4) = output_accumulator[d_chunk] * inv_denominator;
+        for (uint d = 0; d < HEAD_DIM; ++d) {
+            output[out_base + d] = output_accumulator[d] * inv_denominator;
         }
     } else {
         // Handle cases where all scores were masked out; write zeros.
-        for (uint d_chunk = 0; d_chunk < HEAD_DIM / 4; ++d_chunk) {
-            *(device T4*)(output + out_base + d_chunk * 4) = T4(0.0);
+        for (uint d = 0; d < HEAD_DIM; ++d) {
+            output[out_base + d] = T(0.0);
         }
     }
     """
     # EVOLVE-BLOCK-END
-
+
     try:
         # Prepare kernel inputs
         scale_tensor = mx.array([scale], dtype=queries.dtype)
         use_mask_tensor = mx.array([1 if use_mask else 0], dtype=mx.int32)
-
+
         # Create and execute custom Metal kernel
         kernel = mx.fast.metal_kernel(
             name="qwen3_gqa_attention_kernel",
             input_names=["queries", "keys", "values", "mask", "scale", "use_mask"],
             output_names=["output"],
             source=kernel_source,
         )
-
+
         # Optimize thread group size for Apple Silicon
         threadgroup_size = min(32, L)  # Adapt to sequence length
-
+
         # Execute kernel
         outputs = kernel(
             inputs=[queries, keys, values, mask_tensor, scale_tensor, use_mask_tensor],
@@ -203,9 +229,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
203229 ("HEADS_PER_KV" , heads_per_kv ),
204230 ],
205231 )
206-
232+
207233 return outputs [0 ]
208-
234+
209235 except Exception as e :
210236 # Fallback to standard MLX implementation if custom kernel fails
211237 print (f"⚠️ Custom GQA kernel failed: { e } , falling back to MLX SPDA" )
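
For context, the `mx.fast.metal_kernel` API used above takes only the kernel body as `source`; MLX generates the full Metal function signature from `input_names`/`output_names` and the `template` parameters. A minimal, self-contained example in the style of the MLX documentation (the element-wise kernel below is illustrative and unrelated to this diff):

import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)
a = mx.random.normal((4096,))
out = kernel(
    inputs=[a],
    template=[("T", mx.float32)],     # substituted for T in the kernel body
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)[0]
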
@@ -215,7 +241,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
 class CustomGQAAttention(nn.Module):
     """
     Qwen3 attention module with custom Metal kernel optimization.
-
+
     This module integrates the custom Metal kernel while maintaining
     compatibility with the standard MLX-LM interface.
     """
@@ -244,7 +270,6 @@ def __init__(self, args):
         # Standard MLX-LM RoPE
         try:
             from mlx_lm.models.rope_utils import initialize_rope
-
             self.rope = initialize_rope(
                 head_dim,
                 base=args.rope_theta,
@@ -255,7 +280,7 @@ def __init__(self, args):
         except ImportError:
             print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE")
             self.rope = None
-
+
         print(f"🔧 Initialized Custom Metal GQA Attention")
         print(f"   📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads // n_kv_heads}:1 ratio)")
         print(f"   🎯 Head dimension: {head_dim}")
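
The module's forward pass lies outside the hunks shown here. For readers skimming the diff, a typical MLX-LM-style integration would look roughly like the sketch below; everything other than `qwen3_custom_gqa_attention` (the projection attributes, the scale, the omitted KV cache and q/k norms) is an assumption, not code from this file.

# Hypothetical __call__ sketch: project to per-head tensors, apply RoPE,
# run the custom kernel, then project back. KV caching and the q/k norms
# used by Qwen3 are omitted for brevity.
def __call__(self, x, mask=None, cache=None):
    B, L, _ = x.shape
    q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
    k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
    v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
    if self.rope is not None:
        q, k = self.rope(q), self.rope(k)
    out = qwen3_custom_gqa_attention(q, k, v, scale=self.scale, mask=mask)
    return self.o_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
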
@@ -424,11 +449,11 @@ class MockArgs:
     output = metal_attn(x, mask=mask)
 
     print(f"✅ Metal GQA output shape: {output.shape}")
-
+
     # Check for valid output
     has_nan = bool(mx.any(mx.isnan(output)))
     has_inf = bool(mx.any(mx.isinf(output)))
-
+
     print(f"✅ Has NaN: {has_nan}, Has Inf: {has_inf}")
 
     # Check output statistics
@@ -444,10 +469,10 @@ class MockArgs:
     k = mx.random.normal((B, 8, L, D))  # 8 KV heads
     v = mx.random.normal((B, 8, L, D))
     scale = 1.0 / math.sqrt(D)
-
+
     kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal")
     print(f"✅ Direct kernel output shape: {kernel_output.shape}")
-
+
     kernel_mean = float(mx.mean(kernel_output))
     kernel_std = float(mx.std(kernel_output))
     print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}")
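
Because the function already falls back to `mx.fast.scaled_dot_product_attention`, that same call doubles as a correctness oracle for the custom kernel. A parity check one might append to this test (illustrative; assumes the installed MLX accepts the string "causal" mask, as recent versions do):

ref_output = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")
max_abs_err = float(mx.max(mx.abs(kernel_output - ref_output)))
print(f"✅ Max |custom - SDPA| error: {max_abs_err:.2e}")
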
@@ -471,7 +496,7 @@ class MockArgs:
     print("Ready for Metal Kernel Evolution")
     print("Evolution focus:")
     print("1. 🔧 Metal kernel source code optimization")
-    print("2. 💾 Memory access pattern improvements for Apple Silicon")
+    print("2. 💾 Memory access pattern improvements for Apple Silicon")
     print("3. 🎯 GQA-specific optimizations for 40:8 head ratio")
     print("4. ⚡ Vectorization and SIMD optimization")
     print("5. 🚀 Thread group and grid configuration tuning")