Commit 258f44b

fixes
1 parent bc47e1e commit 258f44b

File tree

3 files changed: +158, -69 lines


examples/mlx_spda_optimization/evaluator.py

Lines changed: 7 additions & 10 deletions
@@ -321,7 +321,7 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
         # Load the evolved program with better error handling
         spec = importlib.util.spec_from_file_location("evolved_program", program_path)
         evolved_program = importlib.util.module_from_spec(spec)
-
+
         try:
             spec.loader.exec_module(evolved_program)
         except SyntaxError as e:
@@ -333,14 +333,14 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
             return {
                 "basic_functionality": 0.0,
                 "syntax_error": 1.0,
-                "error": f"Syntax error: {str(e)}"
+                "error": f"Syntax error: {str(e)}",
             }
         except Exception as e:
             print(f"[Stage 1] ❌ IMPORT ERROR: {e}")
             return {
                 "basic_functionality": 0.0,
                 "import_error": 1.0,
-                "error": f"Import error: {str(e)}"
+                "error": f"Import error: {str(e)}",
             }

         # Check if the required function exists
@@ -384,7 +384,7 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
             return {
                 "basic_functionality": 0.0,
                 "runtime_error": 1.0,
-                "error": f"Runtime error: {str(e)}"
+                "error": f"Runtime error: {str(e)}",
             }

         # Enhanced scoring for incremental progress
@@ -408,7 +408,7 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
             "basic_functionality": float(basic_score),
             "shape_correct": float(correctness["shape_correct"]),
             "no_nan_inf": float(correctness["no_nan_inf"]),
-            "accuracy_score": float(min(1.0, 1.0 / max(correctness.get('mse', 1e6), 1e-6)))
+            "accuracy_score": float(min(1.0, 1.0 / max(correctness.get("mse", 1e6), 1e-6))),
         }

         print(f"[Stage 1] ✓ Completed with score: {basic_score:.3f}")
@@ -420,12 +420,9 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
     except Exception as e:
         print(f"[Stage 1] ❌ Unexpected Exception: {str(e)}")
         import traceback
+
         traceback.print_exc()
-        return {
-            "basic_functionality": 0.0,
-            "unexpected_error": 1.0,
-            "error": str(e)
-        }
+        return {"basic_functionality": 0.0, "unexpected_error": 1.0, "error": str(e)}


 def evaluate(program_path: str) -> Dict[str, float]:
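
Aside (not part of the commit): the accuracy_score expression touched in the @@ -408 hunk maps the correctness MSE into [0, 1]. A minimal standalone sketch of that mapping, using an illustrative helper name that does not exist in evaluator.py:

def accuracy_score(mse: float) -> float:
    # max(mse, 1e-6) guards against dividing by (near) zero;
    # min(1.0, ...) caps the score, so any MSE <= 1.0 already earns the full 1.0.
    return float(min(1.0, 1.0 / max(mse, 1e-6)))

assert accuracy_score(1e-9) == 1.0       # tiny MSE: denominator clamped, score capped at 1.0
assert accuracy_score(4.0) == 0.25       # large MSE: score decays as 1 / MSE
assert accuracy_score(1e6) == 1.0 / 1e6  # 1e6 is the default used when "mse" is missing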

examples/mlx_spda_optimization/initial_program.py

Lines changed: 40 additions & 37 deletions
@@ -21,21 +21,21 @@
 def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
     """
     Metal Kernel-based attention implementation with working building blocks.
-
+
     This function uses simple, working Metal kernels that can be evolved
     to more complex optimizations. Starting simple and building complexity.
-
+
     Args:
         q: Query tensor [B, num_heads, L, head_dim]
-        k: Key tensor [B, num_kv_heads, L_kv, head_dim] 
+        k: Key tensor [B, num_kv_heads, L_kv, head_dim]
         v: Value tensor [B, num_kv_heads, L_kv, head_dim]
         scale: Scaling factor (typically 1/sqrt(head_dim))
         mask: Attention mask or mask type string
-
+
     Returns:
         Attention output with same shape as queries
     """
-
+
     # EVOLVE-BLOCK-START
     """
     WORKING METAL KERNEL IMPLEMENTATION
@@ -55,13 +55,13 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
     - Implement custom softmax kernels
     - Eventually fuse entire attention pipeline
     """
-
+
     # Extract dimensions
     B, n_q_heads, L, head_dim = q.shape
     n_kv_heads = k.shape[1]
     kL = k.shape[2]
     n_repeats = n_q_heads // n_kv_heads
-
+
     # WORKING METAL KERNEL: Element-wise scaling
     # This is a simple, working kernel that can be evolved
     try:
@@ -72,32 +72,32 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
         }
         out[elem] = q[elem] * scale_val;
         """
-
+
         scale_kernel = mx.fast.metal_kernel(
             name="scale_query",
             input_names=["q", "scale_val"],
             output_names=["out"],
             source=scale_source,
         )
-
+
         # Create scale as a scalar array for the kernel
         scale_array = mx.array(float(scale), dtype=q.dtype)
-
+
         q_scaled = scale_kernel(
             inputs=[q, scale_array],
             template=[("T", q.dtype)],
             output_shapes=[q.shape],
             output_dtypes=[q.dtype],
             grid=(q.size, 1, 1),
-            threadgroup=(256, 1, 1)
+            threadgroup=(256, 1, 1),
         )[0]
-
+
         # Metal kernel scaling successful (remove noisy print)
-
+
     except Exception as e:
         # Fallback to reference implementation on any Metal kernel error
         q_scaled = q * scale
-
+
     # Handle GQA with reference implementation (can be evolved later)
     if n_repeats > 1:
         q_reshaped = mx.reshape(q_scaled, [B, n_kv_heads, n_repeats, L, head_dim])
@@ -107,11 +107,11 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
         q_reshaped = q_scaled
         k_expanded = k
         v_expanded = v
-
+
     # Compute attention scores with reference implementation (can be evolved)
     # Evolution opportunity: Replace with custom matmul kernel
     scores = q_reshaped @ mx.swapaxes(k_expanded, -1, -2)
-
+
     # Apply mask with reference implementation (can be evolved)
     if mask is not None:
         if isinstance(mask, str) and mask == "causal":
@@ -120,7 +120,7 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
             k_indices = mx.arange(kL)
             causal_mask = q_indices[:, None] >= k_indices[None]
             scores = mx.where(causal_mask, scores, -mx.array(np.float32(np.inf)))
-        elif hasattr(mask, 'dtype') and mask.dtype == mx.bool_:
+        elif hasattr(mask, "dtype") and mask.dtype == mx.bool_:
             if n_repeats > 1 and mask.ndim >= 3:
                 if mask.shape[-3] == 1:
                     mask = mx.expand_dims(mask, -3)
@@ -129,19 +129,19 @@ def evolved_scaled_dot_product_attention(q, k, v, scale=1.0, mask=None):
             scores = mx.where(mask, scores, -mx.array(np.float32(np.inf)))
         else:
             scores = scores + mask
-
+
     # Apply softmax with reference implementation (can be evolved)
     # Evolution opportunity: Replace with custom softmax kernel
     attention_weights = mx.softmax(scores, axis=-1, precise=True)
-
+
     # Apply attention weights to values (can be evolved)
     # Evolution opportunity: Replace with custom matmul kernel
     out = attention_weights @ v_expanded
-
+
     # Reshape back if needed
     if n_repeats > 1:
         out = mx.reshape(out, [B, n_q_heads, L, head_dim])
-
+
     return out
     # EVOLVE-BLOCK-END

@@ -157,58 +157,61 @@ def create_benchmark_attention_function():
 def test_basic_functionality():
     """Test that the Metal kernel attention works with real kernels"""
     print("Testing Working Metal Kernel attention functionality...")
-
+
     # Small test case to verify kernels work
     B, qL, kL, D, qH, kH = 1, 32, 32, 64, 4, 4
     scale = 1.0 / math.sqrt(D)
-
+
     # Create test inputs
     q = mx.random.normal((B, qH, qL, D))
-    k = mx.random.normal((B, kH, kL, D)) 
+    k = mx.random.normal((B, kH, kL, D))
     v = mx.random.normal((B, kH, kL, D))
-
+
    # Test with working Metal kernel
     print(" Testing with working Metal scaling kernel...")
     output = evolved_scaled_dot_product_attention(q, k, v, scale=scale)
     print(f" ✓ Working kernel test: input {q.shape} -> output {output.shape}")
-
+
     # Test correctness by comparing with reference
     print(" Verifying correctness against reference implementation...")
     from spda_benchmark import mlx_ref_attn
+
     reference_output = mlx_ref_attn(q, k, v, scale=scale)
-
+
     # Check if outputs are close
     max_diff = float(mx.max(mx.abs(output - reference_output)))
     mse = float(mx.mean((output - reference_output) ** 2))
-
+
     print(f" ✓ Max difference vs reference: {max_diff:.2e}")
     print(f" ✓ MSE vs reference: {mse:.2e}")
-
+
     if mse < 1e-6:
         print(" ✓ Accuracy test PASSED")
     else:
         print(" ⚠️ Accuracy test FAILED - need to fix implementation")
-
+
     # Test with different configurations
     test_configs = [
-        (1, 32, 32, 64, 8, 8, None), # No mask
-        (1, 64, 64, 64, 8, 8, "causal"), # Causal mask
-        (1, 32, 32, 64, 8, 4, None), # GQA
+        (1, 32, 32, 64, 8, 8, None),  # No mask
+        (1, 64, 64, 64, 8, 8, "causal"),  # Causal mask
+        (1, 32, 32, 64, 8, 4, None),  # GQA
     ]
-
+
     for B, qL, kL, D, qH, kH, mask_type in test_configs:
         q_test = mx.random.normal((B, qH, qL, D))
         k_test = mx.random.normal((B, kH, kL, D))
         v_test = mx.random.normal((B, kH, kL, D))
-
+
         try:
             output_test = evolved_scaled_dot_product_attention(
                 q_test, k_test, v_test, scale=scale, mask=mask_type
             )
             print(f" ✓ Config test passed: seq={qL}, heads={qH}/{kH}, mask={mask_type}")
         except Exception as e:
-            print(f" ❌ Config test failed: seq={qL}, heads={qH}/{kH}, mask={mask_type}, error={e}")
-
+            print(
+                f" ❌ Config test failed: seq={qL}, heads={qH}/{kH}, mask={mask_type}, error={e}"
+            )
+
     print("🚀 Working Metal Kernel attention tests completed!")
     print(" - Simple Metal scaling kernel working")
     print(" - Reference implementation for complex operations")