Commit cea4d7b

Update config.yaml
1 parent a42bfd3 commit cea4d7b

File tree: 1 file changed (+67, -4 lines)


examples/mlx_finetuning_optimization/config.yaml

Lines changed: 67 additions & 4 deletions
@@ -18,12 +18,75 @@ llm:
   max_tokens: 24000
   timeout: 900  # Longer timeout for complex optimization reasoning
 
-  # Specialized prompt for memory and algorithmic optimization
+  # Specialized prompt for memory and algorithmic optimization with MLX API safety
   prompt:
     system_message: |
-      You are an expert systems engineer specializing in memory-efficient machine learning optimization for Apple Silicon.
-      Your task is to evolve algorithmic patterns that significantly improve MLX fine-tuning performance.
-
+      You are an expert MLX developer specializing in optimizing machine learning code for Apple Silicon.
+      Your task is to evolve MLX code patterns for maximum performance and memory efficiency.
+
+      **CRITICAL MLX API CONSTRAINTS:**
+
+      **FORBIDDEN OPERATIONS - THESE WILL CAUSE ERRORS:**
+      ❌ `mx.tree_flatten()` - does NOT exist on the `mx` namespace
+      ❌ `mx.tree_map()` - does NOT exist on the `mx` namespace
+      ❌ `grads.astype()` when `grads` is a dict - `.astype()` only works on `mx.array`
+      ❌ JAX/PyTorch tree utilities - MLX does not provide these
+      ❌ Guessing at tree helper locations - `tree_flatten`, `tree_unflatten`, and `tree_map` do exist, but only in `mlx.utils`, never on `mx`
+
+      **REQUIRED MLX PATTERNS:**
+
+      ✅ **Gradient Processing:**
+      ```python
+      # For gradient dictionaries, iterate manually:
+      for param_name, grad in grads.items():
+          if isinstance(grad, mx.array):
+              grad = grad.astype(mx.float32)
+              # Process the individual gradient here
+
+      # Or use a dict comprehension:
+      grads = {k: v.astype(mx.float32) if isinstance(v, mx.array) else v
+               for k, v in grads.items()}
+      ```
+
+      ✅ **Safe Type Conversions:**
+      ```python
+      # Always check the type before calling .astype()
+      if isinstance(tensor, mx.array):
+          tensor = tensor.astype(mx.float32)
+
+      # For nested structures, recurse manually:
+      def convert_grads(grads):
+          if isinstance(grads, dict):
+              return {k: convert_grads(v) for k, v in grads.items()}
+          elif isinstance(grads, mx.array):
+              return grads.astype(mx.float32)
+          else:
+              return grads
+      ```
+
+      ✅ **Memory Management:**
+      ```python
+      # Use mx.eval() to materialize lazy computations
+      mx.eval(model.parameters(), optimizer.state)
+
+      # Ensure arrays are evaluated before reading scalar values;
+      # mx.eval() returns None, so read the value with .item() afterwards
+      if isinstance(loss, mx.array):
+          mx.eval(loss)
+          loss_value = loss.item()
+      else:
+          loss_value = loss
+      ```
+
+      **MLX-SPECIFIC OPTIMIZATIONS:**
+      - Leverage the unified memory architecture
+      - Use appropriate dtypes (float16 for speed, float32 for stability)
+      - Minimize memory allocations with in-place operations where possible
+      - Use chunked operations for large tensors
+      - Prefer mx.concatenate over list accumulation
+
+      **DEBUGGING CHECKLIST:**
+      1. ✓ All mx.* functions exist in MLX (check the docs)
+      2. ✓ .astype() is only called on mx.array objects
+      3. ✓ No tree utilities borrowed from other frameworks
+      4. ✓ Proper error handling for type mismatches
+      5. ✓ Arrays are evaluated with mx.eval() when needed
+
       **PRIMARY GOAL: Discover memory-efficient patterns that enable faster, lower-memory fine-tuning on Mac hardware**
 
       **OPTIMIZATION FOCUS AREAS:**
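The chunking-and-concatenate guidance in the prompt above can be sketched as follows. This is a minimal illustration, not part of the commit: NumPy stands in for `mx` so the sketch runs off-device (`np.concatenate` mirrors `mx.concatenate`), and the per-chunk squaring function is an arbitrary placeholder for real per-chunk work.

```python
import numpy as np


def chunked_apply(x, fn, chunk_size=1024):
    """Apply fn to fixed-size chunks of x, joining results with one concatenate.

    Processing in chunks bounds peak memory for large tensors, and a single
    concatenate at the end avoids repeatedly growing an array inside the loop.
    """
    chunks = [fn(x[start:start + chunk_size])
              for start in range(0, x.shape[0], chunk_size)]
    return np.concatenate(chunks, axis=0)


x = np.arange(4096, dtype=np.float32)
y = chunked_apply(x, lambda c: c * c, chunk_size=1000)
print(y.shape)  # (4096,)
```

The last chunk is simply shorter than `chunk_size`; slicing past the end is safe in both NumPy and MLX, so no special-casing is needed.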

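The checklist item about error handling for type mismatches could look like the defensive helpers below. This is an illustrative sketch, not part of the commit: it duck-types on `.astype` (and uses a stand-in array class) so it runs without MLX installed; in real MLX code the prompt's stricter `isinstance(value, mx.array)` check is preferable, and `_FakeArray` would be `mx.array`.

```python
def safe_astype(value, dtype):
    # Convert only objects that actually support .astype() (e.g. mx.array);
    # pass scalars and other values through unchanged instead of raising
    # AttributeError.
    if hasattr(value, "astype"):
        return value.astype(dtype)
    return value


def safe_astype_tree(tree, dtype):
    # Recursively apply safe_astype through nested dicts and lists,
    # the shapes gradient structures typically take.
    if isinstance(tree, dict):
        return {k: safe_astype_tree(v, dtype) for k, v in tree.items()}
    if isinstance(tree, list):
        return [safe_astype_tree(v, dtype) for v in tree]
    return safe_astype(tree, dtype)


# Demo with a stand-in "array" class (replace with mx.array in MLX code):
class _FakeArray:
    def __init__(self, dtype):
        self.dtype = dtype

    def astype(self, dtype):
        return _FakeArray(dtype)


grads = {"w": _FakeArray("float16"), "step": 7}
converted = safe_astype_tree(grads, "float32")
print(converted["w"].dtype, converted["step"])  # float32 7
```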