
Commit 9f99f63

Commit message: f
1 parent 6e321b5 commit 9f99f63

File tree: 9 files changed, +3483 -0 lines changed

examples/mlx_attention_optimization/attention_benchmark.py

Lines changed: 1144 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 392 additions & 0 deletions
@@ -0,0 +1,392 @@
"""
MLX Attention Integration Helper

This module provides utilities to easily integrate OpenEvolve-optimized attention
into existing MLX models for side-by-side comparison and deployment.

Key features:
- Load any MLX model with optimized attention
- Compare standard vs optimized attention performance
- Minimal code changes required (2-3 lines)
- Support for popular models (Qwen, Llama, etc.)
"""

import importlib.util
import os
import time
from typing import Dict, Optional, Tuple, Any

import mlx.core as mx
import mlx.nn as nn

try:
    import mlx_lm
    from mlx_lm import load, generate
    MLX_LM_AVAILABLE = True
except ImportError:
    print("⚠️ mlx_lm not available. Real model integration will be limited.")
    MLX_LM_AVAILABLE = False


class OptimizedAttentionWrapper:
    """Wrapper to replace standard attention with optimized version"""

    def __init__(self, evolved_program_path: str):
        """
        Initialize with path to evolved attention program

        Args:
            evolved_program_path: Path to the best_program.py from OpenEvolve
        """
        self.evolved_program_path = evolved_program_path
        self.evolved_module = None
        self._load_evolved_module()

    def _load_evolved_module(self):
        """Load the evolved attention module"""
        if not os.path.exists(self.evolved_program_path):
            raise FileNotFoundError(f"Evolved program not found: {self.evolved_program_path}")

        spec = importlib.util.spec_from_file_location("evolved_attention", self.evolved_program_path)
        self.evolved_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(self.evolved_module)

        if not hasattr(self.evolved_module, 'create_test_attention_module'):
            raise AttributeError("Evolved module missing create_test_attention_module function")

    def create_optimized_attention(self, hidden_size: int, num_heads: int,
                                   num_kv_heads: int = None, **kwargs) -> nn.Module:
        """
        Create optimized attention module compatible with model architecture

        Args:
            hidden_size: Model hidden dimension
            num_heads: Number of attention heads
            num_kv_heads: Number of key-value heads (for GQA)
            **kwargs: Additional parameters (window_size, query_chunk_size, etc.)

        Returns:
            Optimized attention module
        """
        if num_kv_heads is None:
            num_kv_heads = num_heads

        head_dim = hidden_size // num_heads

        # Set reasonable defaults for optimization parameters
        default_kwargs = {
            'window_size': 64,  # Enable windowed attention
            'query_chunk_size': 64,  # Enable chunking
            'dilation_rate': 1  # No dilation by default
        }
        default_kwargs.update(kwargs)

        try:
            return self.evolved_module.create_test_attention_module(
                hidden_size=hidden_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                **default_kwargs
            )
        except TypeError:
            # Fallback for evolved modules without new parameters
            return self.evolved_module.create_test_attention_module(
                hidden_size=hidden_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim
            )


def load_and_patch_model(model_path: str, evolved_program_path: str,
                         patch_attention: bool = True) -> Tuple[Any, Any]:
    """
    Load a model and optionally patch it with optimized attention

    Args:
        model_path: Path to MLX model
        evolved_program_path: Path to evolved attention program
        patch_attention: Whether to patch attention layers

    Returns:
        Tuple of (model, tokenizer)
    """
    if not MLX_LM_AVAILABLE:
        raise ImportError("mlx_lm required for model loading")

    print(f"📥 Loading model: {model_path}")
    model, tokenizer = load(model_path)

    if patch_attention:
        print(f"🔧 Patching with optimized attention: {evolved_program_path}")
        wrapper = OptimizedAttentionWrapper(evolved_program_path)

        # Try to detect and patch attention layers
        # This is model-specific and may need adjustment for different architectures
        patched_count = _patch_model_attention(model, wrapper)
        print(f"✅ Patched {patched_count} attention layers")

    return model, tokenizer


def _patch_model_attention(model: nn.Module, wrapper: OptimizedAttentionWrapper) -> int:
    """
    Attempt to patch attention layers in a model
    This is a heuristic approach that works for common architectures

    Args:
        model: MLX model to patch
        wrapper: Optimized attention wrapper

    Returns:
        Number of layers patched
    """
    patched_count = 0

    # Common patterns for attention layer names
    attention_patterns = [
        'self_attn', 'attention', 'attn', 'multi_head_attention'
    ]

    def _recursive_patch(module, name_prefix=""):
        nonlocal patched_count

        for name, child in module.__dict__.items():
            if isinstance(child, nn.Module):
                full_name = f"{name_prefix}.{name}" if name_prefix else name

                # Check if this is an attention layer
                if any(pattern in name.lower() for pattern in attention_patterns):
                    try:
                        # Try to extract architecture details
                        if hasattr(child, 'hidden_size') and hasattr(child, 'num_heads'):
                            hidden_size = child.hidden_size
                            num_heads = child.num_heads
                            num_kv_heads = getattr(child, 'num_kv_heads', num_heads)

                            # Create optimized replacement
                            optimized_attn = wrapper.create_optimized_attention(
                                hidden_size=hidden_size,
                                num_heads=num_heads,
                                num_kv_heads=num_kv_heads
                            )

                            # Replace the attention layer
                            setattr(module, name, optimized_attn)
                            patched_count += 1
                            print(f"   Patched: {full_name}")

                    except Exception as e:
                        print(f"   ⚠️ Failed to patch {full_name}: {str(e)}")

                # Recursively check children
                _recursive_patch(child, full_name)

    _recursive_patch(model)
    return patched_count


def compare_attention_performance(model_path: str, evolved_program_path: str,
                                  prompt: str = "Write a Python function that",
                                  max_tokens: int = 100, runs: int = 3) -> Dict[str, Any]:
    """
    Compare performance between standard and optimized attention

    Args:
        model_path: Path to MLX model
        evolved_program_path: Path to evolved attention program
        prompt: Test prompt for generation
        max_tokens: Maximum tokens to generate
        runs: Number of benchmark runs

    Returns:
        Performance comparison results
    """

    if not MLX_LM_AVAILABLE:
        raise ImportError("mlx_lm required for performance comparison")

    print(f"⚖️ Comparing attention performance...")
    print(f"   Model: {model_path}")
    print(f"   Prompt: '{prompt}'")
    print(f"   Max tokens: {max_tokens}")

    results = {
        "model_path": model_path,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "runs": runs
    }

    # Test standard attention
    print(f"\n📊 Testing standard attention...")
    standard_model, tokenizer = load(model_path)
    standard_times = []

    for run in range(runs):
        start_time = time.time()
        try:
            response = generate(standard_model, tokenizer, prompt,
                                max_tokens=max_tokens, verbose=False)
            end_time = time.time()

            run_time = end_time - start_time
            standard_times.append(run_time)

            tokens_generated = len(response.split()) - len(prompt.split())
            tokens_per_sec = tokens_generated / run_time if run_time > 0 else 0

            print(f"   Run {run+1}: {run_time:.2f}s ({tokens_per_sec:.1f} tokens/sec)")

        except Exception as e:
            print(f"   Run {run+1} failed: {str(e)}")
            standard_times.append(float('inf'))

    # Test optimized attention
    print(f"\n🚀 Testing optimized attention...")
    optimized_model, tokenizer = load_and_patch_model(model_path, evolved_program_path)
    optimized_times = []

    for run in range(runs):
        start_time = time.time()
        try:
            response = generate(optimized_model, tokenizer, prompt,
                                max_tokens=max_tokens, verbose=False)
            end_time = time.time()

            run_time = end_time - start_time
            optimized_times.append(run_time)

            tokens_generated = len(response.split()) - len(prompt.split())
            tokens_per_sec = tokens_generated / run_time if run_time > 0 else 0

            print(f"   Run {run+1}: {run_time:.2f}s ({tokens_per_sec:.1f} tokens/sec)")

        except Exception as e:
            print(f"   Run {run+1} failed: {str(e)}")
            optimized_times.append(float('inf'))

    # Calculate comparison
    valid_standard = [t for t in standard_times if t < float('inf')]
    valid_optimized = [t for t in optimized_times if t < float('inf')]

    if valid_standard and valid_optimized:
        avg_standard = sum(valid_standard) / len(valid_standard)
        avg_optimized = sum(valid_optimized) / len(valid_optimized)
        speedup = avg_standard / avg_optimized if avg_optimized > 0 else 0

        results.update({
            "standard_avg_time": avg_standard,
            "optimized_avg_time": avg_optimized,
            "speedup": speedup,
            "standard_successful_runs": len(valid_standard),
            "optimized_successful_runs": len(valid_optimized),
            "improvement": "Yes" if speedup > 1.05 else "Minimal" if speedup > 1.0 else "No"
        })

        print(f"\n📈 RESULTS:")
        print(f"   Standard attention: {avg_standard:.2f}s average")
        print(f"   Optimized attention: {avg_optimized:.2f}s average")
        print(f"   Speedup: {speedup:.2f}x")
        print(f"   Improvement: {results['improvement']}")

    else:
        results["error"] = "Insufficient successful runs for comparison"
        print(f"\n❌ Comparison failed: insufficient successful runs")

    return results


def quick_demo(evolved_program_path: str,
               model_path: str = "mlx-community/Qwen3-0.6B-bf16"):
    """
    Quick demonstration of optimized attention

    Args:
        evolved_program_path: Path to evolved attention program
        model_path: Model to test with
    """

    print("🚀 OpenEvolve Optimized Attention Demo")
    print("=" * 50)

    try:
        # Load model with optimized attention
        print(f"\n1️⃣ Loading model with optimized attention...")
        model, tokenizer = load_and_patch_model(model_path, evolved_program_path)

        # Test prompts
        test_prompts = [
            "Write a Python function that calculates fibonacci numbers:",
            "Explain machine learning in simple terms:",
            "Create a haiku about programming:"
        ]

        print(f"\n2️⃣ Testing text generation...")
        for i, prompt in enumerate(test_prompts, 1):
            print(f"\n   Test {i}: {prompt}")

            start_time = time.time()
            response = generate(model, tokenizer, prompt, max_tokens=50, verbose=False)
            end_time = time.time()

            generation_time = end_time - start_time
            tokens_generated = len(response.split()) - len(prompt.split())
            tokens_per_sec = tokens_generated / generation_time if generation_time > 0 else 0

            print(f"   Response: {response[len(prompt):].strip()}")
            print(f"   Performance: {generation_time:.2f}s ({tokens_per_sec:.1f} tokens/sec)")

        print(f"\n✅ Demo complete! The optimized attention is working.")
        print(f"   Run the full benchmark for detailed performance comparisons.")

    except Exception as e:
        print(f"\n❌ Demo failed: {str(e)}")
        raise


def main():
    """Command-line interface for attention integration"""

    import argparse

    parser = argparse.ArgumentParser(description="MLX Attention Integration Helper")

    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # Demo command
    demo_parser = subparsers.add_parser('demo', help='Quick demonstration')
    demo_parser.add_argument('--evolved-program', required=True,
                             help='Path to evolved attention program')
    demo_parser.add_argument('--model', default='mlx-community/Qwen3-0.6B-bf16',
                             help='Model to test with')

    # Compare command
    compare_parser = subparsers.add_parser('compare', help='Compare standard vs optimized')
    compare_parser.add_argument('--evolved-program', required=True,
                                help='Path to evolved attention program')
    compare_parser.add_argument('--model', default='mlx-community/Qwen3-0.6B-bf16',
                                help='Model to test with')
    compare_parser.add_argument('--prompt', default='Write a Python function that',
                                help='Test prompt')
    compare_parser.add_argument('--max-tokens', type=int, default=100,
                                help='Maximum tokens to generate')
    compare_parser.add_argument('--runs', type=int, default=3,
                                help='Number of benchmark runs')

    args = parser.parse_args()

    if args.command == 'demo':
        quick_demo(args.evolved_program, args.model)
    elif args.command == 'compare':
        compare_attention_performance(
            args.model, args.evolved_program,
            args.prompt, args.max_tokens, args.runs
        )
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
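Usage sketch (not part of the commit): the module docstring above claims integration takes only 2-3 lines of code. The name of this 392-line helper file is not shown in the diff, so the module name and paths below are placeholders; assuming the file is saved as integration_helper.py next to an evolved program produced by OpenEvolve, a minimal invocation might look like this.

# Hypothetical usage of the helper defined above.
# "integration_helper" and the paths below are placeholders, not names taken from the commit.
from integration_helper import load_and_patch_model, compare_attention_performance

EVOLVED = "openevolve_output/best_program.py"  # placeholder path to the evolved attention program
MODEL = "mlx-community/Qwen3-0.6B-bf16"        # default model used by the helper's CLI

# One call loads the model and swaps in the evolved attention layers.
model, tokenizer = load_and_patch_model(MODEL, EVOLVED)

# Optionally benchmark standard vs. patched attention and inspect the speedup.
results = compare_attention_performance(MODEL, EVOLVED, max_tokens=100, runs=3)
print(results.get("speedup"))

# Equivalent CLI entry points defined in main() above:
#   python integration_helper.py demo --evolved-program openevolve_output/best_program.py
#   python integration_helper.py compare --evolved-program openevolve_output/best_program.py --runs 3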
