Commit bf73e00 ("fix"), 1 parent cea4d7b

1 file changed: +104 −2 lines

examples/mlx_finetuning_optimization/config.yaml (104 additions, 2 deletions)
@@ -32,6 +32,10 @@ prompt:
 ❌ `grads.astype()` when grads is a dict - Only works on mx.array
 ❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
 ❌ `mlx.utils.tree_*` functions - These don't exist
+❌ Assuming `mx.eval()` always returns arrays - Can return None
+❌ Modulo operations without checking for zero divisors
+❌ Assuming trainer attributes exist without checking
+❌ Accessing array indices without checking if the array exists

 **REQUIRED MLX PATTERNS:**
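The dict-vs-array distinction behind the first forbidden pattern can be shown framework-free: `.astype()` belongs to individual arrays, so casting a gradient dict means recursing over its leaves. A minimal sketch (the `cast_grads` helper is hypothetical, and plain `float` stands in for a dtype cast so the example runs without MLX):

```python
# Hypothetical sketch: casting every leaf of a nested gradient dict.
# `cast` stands in for a per-array dtype conversion; plain floats are
# used so the example runs without MLX installed.

def cast_grads(grads, cast):
    """Recursively apply `cast` to every leaf of a nested dict."""
    if isinstance(grads, dict):
        return {name: cast_grads(value, cast) for name, value in grads.items()}
    return cast(grads)  # a leaf: cast it directly

grads = {"w": 1, "bias": {"b": 2}}
casted = cast_grads(grads, float)
```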
@@ -69,8 +73,63 @@ prompt:
 # Use mx.eval() to materialize computations
 mx.eval(model.parameters(), optimizer.state)

-# Ensure arrays are evaluated before accessing
-loss_value = mx.eval(loss)[0] if isinstance(loss, mx.array) else loss
+# SAFE: Check mx.eval() return values before indexing
+eval_result = mx.eval(loss)
+if eval_result is not None:
+    loss_value = eval_result[0] if isinstance(eval_result, mx.array) else eval_result
+else:
+    loss_value = float(loss) if hasattr(loss, '__float__') else 0.0
+
+# SAFE: Alternative pattern - use .item() on a scalar array after evaluation
+loss_value = loss.item() if isinstance(loss, mx.array) else float(loss)
+```
+
+✅ **Safe Arithmetic Operations:**
+```python
+# SAFE: Check for zero before modulo operations
+if total_accumulation_steps > 0 and (accumulation_step + 1) % total_accumulation_steps == 0:
+    # Perform update
+    pass
+
+# SAFE: Division with fallback
+batch_size = len(batch) if batch is not None and len(batch) > 0 else 1
+normalized_loss = total_loss / max(batch_size, 1)
+```
+
+✅ **Safe Attribute Access:**
+```python
+# SAFE: Check attributes before accessing
+if hasattr(trainer, 'accumulated_grads'):
+    grads = trainer.accumulated_grads
+else:
+    # Initialize if needed
+    trainer.accumulated_grads = {}
+    grads = trainer.accumulated_grads
+
+# SAFE: Use getattr with defaults
+accumulated_grads = getattr(trainer, 'accumulated_grads', None)
+if accumulated_grads is None:
+    accumulated_grads = {}
+    setattr(trainer, 'accumulated_grads', accumulated_grads)
+```
+
+✅ **Safe Array Operations:**
+```python
+# SAFE: Check array existence and shape before indexing
+if isinstance(tensor, mx.array) and tensor.size > 0:
+    first_element = tensor[0]
+else:
+    first_element = 0.0
+
+# SAFE: Robust tensor evaluation
+def safe_eval(tensor):
+    if tensor is None:
+        return None
+    try:
+        result = mx.eval(tensor)
+        return result if result is not None else tensor
+    except Exception:
+        return tensor
 ```

 **MLX-SPECIFIC OPTIMIZATIONS:**
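The loss-handling branches in the hunk above can be folded into a single helper. A framework-agnostic sketch (`safe_scalar` is a hypothetical name, not part of the diff; it uses duck typing so it runs without MLX):

```python
def safe_scalar(value, default=0.0):
    """Coerce a possibly-None loss value into a float, mirroring the
    defensive checks above: None, plain numbers, scalar arrays, fallback."""
    if value is None:  # e.g. an eval call that returned nothing
        return default
    if isinstance(value, (int, float)):
        return float(value)
    item = getattr(value, "item", None)  # scalar arrays expose .item()
    if callable(item):
        try:
            return float(item())
        except (TypeError, ValueError):
            return default
    try:
        return float(value)  # last resort: anything float()-convertible
    except (TypeError, ValueError):
        return default
```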
@@ -86,9 +145,49 @@ prompt:
 3. ✓ No tree utilities from other frameworks
 4. ✓ Proper error handling for type mismatches
 5. ✓ Arrays evaluated with mx.eval() when needed
+6. ✓ Check mx.eval() return values before indexing
+7. ✓ Verify divisors are non-zero before modulo/division
+8. ✓ Check object attributes exist before accessing
+9. ✓ Handle None and empty arrays gracefully
+10. ✓ Use safe fallbacks for all operations

 **PRIMARY GOAL: Discover memory-efficient patterns that enable faster, lower-memory fine-tuning on Mac hardware**

+**COMMON RUNTIME ERROR PATTERNS TO AVOID:**
+
+❌ **'NoneType' object is not subscriptable**
+```python
+# WRONG: loss_value = mx.eval(loss)[0]  # mx.eval() might return None
+# RIGHT:
+eval_result = mx.eval(loss)
+loss_value = eval_result[0] if eval_result is not None else 0.0
+```
+
+❌ **integer modulo by zero**
+```python
+# WRONG: if step % accumulation_steps == 0:  # accumulation_steps might be 0
+# RIGHT:
+if accumulation_steps > 0 and step % accumulation_steps == 0:
+    ...
+```
+
+❌ **'object' has no attribute**
+```python
+# WRONG: trainer.accumulated_grads  # attribute might not exist
+# RIGHT:
+if hasattr(trainer, 'accumulated_grads'):
+    grads = trainer.accumulated_grads
+else:
+    trainer.accumulated_grads = {}
+    grads = trainer.accumulated_grads
+```
+
+❌ **TypeError: unsupported operand type(s)**
+```python
+# WRONG: loss = loss1 + loss2  # types might be incompatible
+# RIGHT:
+loss = (float(loss1) + float(loss2)) if loss1 is not None and loss2 is not None else 0.0
+```
+
 **OPTIMIZATION FOCUS AREAS:**

 **Memory-Efficient Attention Patterns:**
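The zero-divisor guard from the error patterns above generalizes to a single predicate. A sketch (`should_update` is a hypothetical helper for illustration, not part of the diff):

```python
def should_update(step, accumulation_steps):
    """Decide whether to apply an optimizer update on this step,
    guarding against a zero or missing accumulation divisor."""
    if not accumulation_steps or accumulation_steps <= 0:
        return True  # fall back to updating on every step
    return (step + 1) % accumulation_steps == 0
```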
@@ -179,6 +278,9 @@ prompt:
 - Balance memory savings with computational overhead
 - Maintain numerical stability and training quality
 - Consider Apple Silicon architecture specifics
+- **ALWAYS use defensive programming: check types, values, and attributes**
+- **NEVER assume function return values or object states**
+- **INCLUDE error handling and safe fallbacks in all operations**

 **IMPLEMENTATION CONSTRAINTS:**
 - Must use MLX operations and data types
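Taken together, the defensive bullets above amount to: never assume the gradient accumulator exists before touching it. A guarded accumulation sketch (the `accumulate_into` name and plain-float gradients are hypothetical; real code would accumulate mx.array leaves):

```python
def accumulate_into(trainer, grads):
    """Add `grads` into trainer.accumulated_grads, creating the store
    if absent and tolerating a None gradient dict."""
    store = getattr(trainer, "accumulated_grads", None)  # never assume it exists
    if store is None:
        store = {}
        trainer.accumulated_grads = store
    for name, g in (grads or {}).items():  # tolerate grads=None
        store[name] = store.get(name, 0.0) + g
    return store

class Trainer:  # minimal stand-in for the real trainer object
    pass

t = Trainer()
accumulate_into(t, {"w": 1.0})
accumulate_into(t, {"w": 2.0, "b": 0.5})
```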
