def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
    """
    Metal kernel-based attention implementation with working building blocks.

    This function uses simple, working Metal kernels that can be evolved
    into more complex optimizations, starting simple and building up complexity.

    Args:
        q: Query tensor [B, num_heads, L, head_dim]
        k: Key tensor [B, num_kv_heads, L_kv, head_dim]
        v: Value tensor [B, num_kv_heads, L_kv, head_dim]
        scale: Scaling factor (typically 1/sqrt(head_dim))
        mask: Attention mask or mask type string

    Returns:
        Attention output with the same shape as the queries
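
    Example (illustrative shapes, consistent with the signature above):
        >>> q = mx.random.normal((1, 8, 128, 64))
        >>> k = v = mx.random.normal((1, 8, 128, 64))
        >>> out = evolved_scaled_dot_product_attention(
        ...     q, k, v, scale=1.0 / math.sqrt(64), mask="causal"
        ... )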
3737 """
38-
38+
3939 # EVOLVE-BLOCK-START
4040 """
4141 WORKING METAL KERNEL IMPLEMENTATION
@@ -55,13 +55,13 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
    - Implement custom softmax kernels
    - Eventually fuse entire attention pipeline
    """

    # Extract dimensions
    B, n_q_heads, L, head_dim = q.shape
    n_kv_heads = k.shape[1]
    kL = k.shape[2]
    n_repeats = n_q_heads // n_kv_heads
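    # Example: 8 query heads with 4 KV heads gives n_repeats == 2, i.e. each KV
    # head is shared by two query heads (grouped-query attention).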

    # WORKING METAL KERNEL: Element-wise scaling
    # This is a simple, working kernel that can be evolved
    try:
@@ -72,32 +72,32 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
            }
            out[elem] = q[elem] * scale_val;
        """

        scale_kernel = mx.fast.metal_kernel(
            name="scale_query",
            input_names=["q", "scale_val"],
            output_names=["out"],
            source=scale_source,
        )

        # Create scale as a scalar array for the kernel
        scale_array = mx.array(float(scale), dtype=q.dtype)

        q_scaled = scale_kernel(
            inputs=[q, scale_array],
            template=[("T", q.dtype)],
            output_shapes=[q.shape],
            output_dtypes=[q.dtype],
            grid=(q.size, 1, 1),
            threadgroup=(256, 1, 1),
        )[0]
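        # Launch-geometry note (assumption about mx.fast.metal_kernel semantics):
        # grid gives the total thread count, so this dispatches one thread per
        # element of q, grouped into threadgroups of 256 threads.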

        # Metal kernel scaling successful (remove noisy print)

    except Exception as e:
        # Fallback to reference implementation on any Metal kernel error
        q_scaled = q * scale

    # Handle GQA with reference implementation (can be evolved later)
    if n_repeats > 1:
        q_reshaped = mx.reshape(q_scaled, [B, n_kv_heads, n_repeats, L, head_dim])
@@ -107,11 +107,11 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
        q_reshaped = q_scaled
        k_expanded = k
        v_expanded = v
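        # Note on the GQA branch above (a common pattern, assumed rather than
        # taken from the elided lines of this file): k and v can be broadcast
        # across the repeat axis, e.g.
        #   k_expanded = mx.expand_dims(k, 2)  # [B, n_kv_heads, 1, L_kv, head_dim]
        #   v_expanded = mx.expand_dims(v, 2)
        # so the batched matmul below broadcasts each KV head over its query group.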

    # Compute attention scores with reference implementation (can be evolved)
    # Evolution opportunity: Replace with custom matmul kernel
    scores = q_reshaped @ mx.swapaxes(k_expanded, -1, -2)

    # Apply mask with reference implementation (can be evolved)
    if mask is not None:
        if isinstance(mask, str) and mask == "causal":
@@ -120,7 +120,7 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
            k_indices = mx.arange(kL)
            causal_mask = q_indices[:, None] >= k_indices[None]
            scores = mx.where(causal_mask, scores, -mx.array(np.float32(np.inf)))
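            # Worked example (L == kL == 3): causal_mask is
            #   [[True, False, False],
            #    [True, True,  False],
            #    [True, True,  True ]]
            # so each query attends only to keys at or before its own position.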
        elif hasattr(mask, "dtype") and mask.dtype == mx.bool_:
            if n_repeats > 1 and mask.ndim >= 3:
                if mask.shape[-3] == 1:
                    mask = mx.expand_dims(mask, -3)
@@ -129,19 +129,19 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
            scores = mx.where(mask, scores, -mx.array(np.float32(np.inf)))
        else:
            scores = scores + mask

    # Apply softmax with reference implementation (can be evolved)
    # Evolution opportunity: Replace with custom softmax kernel
    attention_weights = mx.softmax(scores, axis=-1, precise=True)
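    # Sketch of what a custom softmax kernel would need to reproduce (the stable
    # formulation is an assumption suggested by precise=True, not taken from MLX):
    #   m = mx.max(scores, axis=-1, keepdims=True)
    #   attention_weights = mx.exp(scores - m) / mx.sum(mx.exp(scores - m), axis=-1, keepdims=True)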

    # Apply attention weights to values (can be evolved)
    # Evolution opportunity: Replace with custom matmul kernel
    out = attention_weights @ v_expanded

    # Reshape back if needed
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, head_dim])

    return out
    # EVOLVE-BLOCK-END

@@ -157,58 +157,61 @@ def create_benchmark_attention_function():
def test_basic_functionality():
    """Test that the Metal kernel attention works with real kernels"""
    print("Testing Working Metal Kernel attention functionality...")

    # Small test case to verify kernels work
    B, qL, kL, D, qH, kH = 1, 32, 32, 64, 4, 4
    scale = 1.0 / math.sqrt(D)

    # Create test inputs
    q = mx.random.normal((B, qH, qL, D))
    k = mx.random.normal((B, kH, kL, D))
    v = mx.random.normal((B, kH, kL, D))

    # Test with working Metal kernel
    print(" Testing with working Metal scaling kernel...")
    output = evolved_scaled_dot_product_attention(q, k, v, scale=scale)
    print(f" ✓ Working kernel test: input {q.shape} -> output {output.shape}")

    # Test correctness by comparing with reference
    print(" Verifying correctness against reference implementation...")
    from spda_benchmark import mlx_ref_attn

    reference_output = mlx_ref_attn(q, k, v, scale=scale)

    # Check if outputs are close
    max_diff = float(mx.max(mx.abs(output - reference_output)))
    mse = float(mx.mean((output - reference_output) ** 2))
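    # max_diff catches the single worst element, while MSE averages squared error
    # over all elements and can hide isolated outliers; together they give a
    # fuller picture of accuracy.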

    print(f" ✓ Max difference vs reference: {max_diff:.2e}")
    print(f" ✓ MSE vs reference: {mse:.2e}")

    if mse < 1e-6:
        print(" ✓ Accuracy test PASSED")
    else:
        print(" ⚠️ Accuracy test FAILED - need to fix implementation")

    # Test with different configurations
    test_configs = [
        (1, 32, 32, 64, 8, 8, None),  # No mask
        (1, 64, 64, 64, 8, 8, "causal"),  # Causal mask
        (1, 32, 32, 64, 8, 4, None),  # GQA
    ]
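    # Each tuple is (B, qL, kL, D, qH, kH, mask); the last config has fewer KV
    # heads than query heads, which exercises the GQA path.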

    for B, qL, kL, D, qH, kH, mask_type in test_configs:
        q_test = mx.random.normal((B, qH, qL, D))
        k_test = mx.random.normal((B, kH, kL, D))
        v_test = mx.random.normal((B, kH, kL, D))

        try:
            output_test = evolved_scaled_dot_product_attention(
                q_test, k_test, v_test, scale=scale, mask=mask_type
            )
            print(f" ✓ Config test passed: seq={qL}, heads={qH}/{kH}, mask={mask_type}")
        except Exception as e:
            print(
                f" ❌ Config test failed: seq={qL}, heads={qH}/{kH}, mask={mask_type}, error={e}"
            )

    print("🚀 Working Metal Kernel attention tests completed!")
    print(" - Simple Metal scaling kernel working")
    print(" - Reference implementation for complex operations")