
Commit 6ab66ed
commit message: t
1 parent: f3174f9

3 files changed (+52, -39 lines)


examples/mlx_attention_optimization/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -86,5 +86,5 @@ evaluator:
 
 # Evolution settings for attention optimization
 diff_based_evolution: true
-allow_full_rewrites: true  # Allow full rewrites for significant attention improvements
+allow_full_rewrites: false  # Allow full rewrites for significant attention improvements
 max_code_length: 100000  # Larger for complex attention implementations

examples/mlx_attention_optimization/evaluator.py

Lines changed: 31 additions & 9 deletions
@@ -18,6 +18,28 @@
 import os
 
 
+def safe_float_conversion(value, default=0.0):
+    """Safely convert a value to float, handling infinity and NaN"""
+    try:
+        float_val = float(value)
+        if np.isnan(float_val) or np.isinf(float_val):
+            return default
+        return float_val
+    except (TypeError, ValueError, OverflowError):
+        return default
+
+
+def safe_division(numerator, denominator, default=0.0):
+    """Safely perform division, handling zero denominators and infinity"""
+    try:
+        if denominator == 0 or denominator is None:
+            return default
+        result = numerator / denominator
+        return safe_float_conversion(result, default)
+    except (TypeError, ValueError, OverflowError, ZeroDivisionError):
+        return default
+
+
 def run_with_timeout(func, args=(), kwargs={}, timeout_seconds=60):
     """Run a function with timeout using concurrent.futures"""
     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
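Note: taken on their own, the two new helpers clamp every non-finite value or failed conversion to the default. A minimal sketch of the edge cases they absorb (hypothetical REPL session, relying on the module's existing import numpy as np):

    safe_float_conversion(3.5)             # 3.5
    safe_float_conversion(float("nan"))    # 0.0 (NaN replaced by the default)
    safe_float_conversion(float("inf"))    # 0.0 (infinity likewise)
    safe_float_conversion("not a number")  # 0.0 (ValueError caught)
    safe_division(10.0, 2.0)               # 5.0
    safe_division(10.0, 0.0)               # 0.0 (zero denominator short-circuits)
    safe_division(10.0, None)              # 0.0 (None denominator short-circuits)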
@@ -280,19 +302,19 @@ def benchmark_model_inference(program, config: Dict[str, Any]) -> Dict[str, Any]
         opt_time = np.mean(opt_times)
 
         # Calculate speedup
-        speedup = ref_time / opt_time if opt_time > 0 else 0.0
+        speedup = safe_division(ref_time, opt_time, 0.0)
 
-        # Calculate throughput (tokens/second)
+        # Calculate throughput (tokens/second)
         total_tokens = batch_size * seq_len
-        ref_throughput = total_tokens / ref_time
-        opt_throughput = total_tokens / opt_time
+        ref_throughput = safe_division(total_tokens, ref_time, 0.0)
+        opt_throughput = safe_division(total_tokens, opt_time, 0.0)
 
         results[config_name] = {
-            "reference_time": ref_time,
-            "optimized_time": opt_time,
-            "speedup": speedup,
-            "ref_throughput": ref_throughput,
-            "opt_throughput": opt_throughput,
+            "reference_time": safe_float_conversion(ref_time),
+            "optimized_time": safe_float_conversion(opt_time),
+            "speedup": safe_float_conversion(speedup),
+            "ref_throughput": safe_float_conversion(ref_throughput),
+            "opt_throughput": safe_float_conversion(opt_throughput),
             "model_config": model_config
         }
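Note: one plausible motivation for sanitizing the metrics dict (not stated in the commit itself) is that inf and NaN survive Python's json.dumps only as non-standard tokens that strict JSON parsers reject. A quick illustration:

    import json

    json.dumps({"speedup": float("inf")})                    # '{"speedup": Infinity}' (not valid JSON)
    json.dumps({"speedup": float("inf")}, allow_nan=False)   # raises ValueError
    json.dumps({"speedup": safe_float_conversion(float("inf"))})  # '{"speedup": 0.0}'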

examples/mlx_attention_optimization/initial_program.py

Lines changed: 20 additions & 29 deletions
@@ -80,9 +80,9 @@ def optimized_attention_kernel(
     if value.dtype != mx.float32:
         value = value.astype(mx.float32)
 
-    # Determine scale factor
+    # Determine scale factor - make sure it matches reference implementation
     if scale_strategy == "sqrt_dk":
-        scale = 1.0 / math.sqrt(d_model)
+        scale = 1.0 / math.sqrt(d_model)  # This should match reference
     elif scale_strategy == "learned":
         # Slightly different scale as a heuristic
         scale = 0.9 / math.sqrt(d_model)
@@ -92,24 +92,20 @@ def optimized_attention_kernel(
     # For now, implement basic attention to ensure correctness
     # More complex optimizations will be evolved
 
-    # Compute attention scores
-    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
-
-    # Apply scaling
-    scores = scores * scale
+    # Compute attention scores - match reference implementation exactly
+    if scale_strategy == "sqrt_dk":
+        # Match reference exactly: scores = matmul(...) / sqrt(d_k)
+        scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1))) / math.sqrt(d_model)
+    else:
+        # For other strategies, compute separately
+        scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
+        scores = scores * scale
 
-    # Apply mask if provided
+    # Apply mask if provided - match reference implementation
     if mask is not None:
-        # Ensure mask has the right shape and dtype
-        if mask.shape != scores.shape:
-            # Handle different mask shapes - broadcast if needed
-            if len(mask.shape) == 2:  # [seq_len, seq_len]
-                mask = mx.broadcast_to(mask[None, :, :], scores.shape)
-            elif len(mask.shape) == 3 and mask.shape[0] == 1:  # [1, seq_len, seq_len]
-                mask = mx.broadcast_to(mask, scores.shape)
-
-        mask_value = -1e9 if compute_dtype == mx.float32 else -1e4
-        scores = scores + mask * mask_value
+        # Reference implementation does: scores = scores + mask
+        # So mask should already contain the large negative values
+        scores = scores + mask
 
     # Compute attention weights (always use high precision initially)
     attention_weights = mx.softmax(scores, axis=-1)
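Note: dividing by math.sqrt(d_model) inside the matmul expression is algebraically the same as multiplying by a precomputed scale = 1.0 / math.sqrt(d_model), but the two can round differently in float32, which matters when the goal is an exact numerical match with the reference. A standalone sketch of the effect (plain numpy, hypothetical sizes):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal(1024).astype(np.float32)
    s = np.float32(np.sqrt(80.0))  # hypothetical d_model = 80

    direct = x / s                               # divide by sqrt(d_model), as the reference does
    via_reciprocal = x * (np.float32(1.0) / s)   # multiply by a rounded reciprocal

    print(np.array_equal(direct, via_reciprocal))  # often False: results can differ by 1 ulp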
@@ -138,18 +134,13 @@ def _chunked_attention(
     """
     # For now, fall back to standard attention to ensure correctness
     # Evolution will implement proper chunking
-    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
-    scores = scores * scale
+    d_model = query.shape[-1]
+
+    # Match reference implementation exactly
+    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1))) / math.sqrt(d_model)
 
     if mask is not None:
-        if mask.shape != scores.shape:
-            if len(mask.shape) == 2:  # [seq_len, seq_len]
-                mask = mx.broadcast_to(mask[None, :, :], scores.shape)
-            elif len(mask.shape) == 3 and mask.shape[0] == 1:  # [1, seq_len, seq_len]
-                mask = mx.broadcast_to(mask, scores.shape)
-
-        mask_value = -1e9 if scores.dtype == mx.float32 else -1e4
-        scores = scores + mask * mask_value
+        scores = scores + mask
 
     attention_weights = mx.softmax(scores, axis=-1)
     output = mx.matmul(attention_weights, value)
@@ -230,7 +221,7 @@ def benchmark_attention(
 
     # Create causal mask for decoder attention
     mask = mx.triu(mx.ones((seq_len, seq_len)), k=1) * -1e9
-    mask = mx.broadcast_to(mask[None, None, :, :], (batch_size, 1, seq_len, seq_len))
+    mask = mx.broadcast_to(mask[None, :, :], (batch_size, seq_len, seq_len))
 
     # Warmup
     _ = optimized_attention_kernel(query, key, value, mask, **config)
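Note: the 3-D mask now matches the rank of the scores the kernel actually computes: with query, key, and value of shape [batch, seq_len, d_model], mx.matmul(query, mx.transpose(key, axes=(0, 2, 1))) yields [batch, seq_len, seq_len], so the additive mask applies without reshaping. A shape-check sketch (hypothetical sizes):

    import mlx.core as mx

    batch_size, seq_len, d_model = 2, 16, 64
    query = mx.random.normal((batch_size, seq_len, d_model))
    key = mx.random.normal((batch_size, seq_len, d_model))

    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
    # Additive causal mask: 0 on/below the diagonal, -1e9 above it, so
    # scores + mask drives disallowed positions to ~zero after softmax.
    mask = mx.broadcast_to(
        (mx.triu(mx.ones((seq_len, seq_len)), k=1) * -1e9)[None, :, :],
        (batch_size, seq_len, seq_len),
    )
    print(scores.shape, mask.shape)  # both (2, 16, 16): mask adds elementwise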
