@@ -32,6 +32,45 @@ prompt:
3232 ❌ `grads.astype()` when grads is a dict - Only works on mx.array
3333 ❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
3434 ❌ `mlx.utils.tree_*` functions - These don't exist
35+ ❌ `model.update_parameters()` - MLX models don't have this method
36+ ❌ `float(loss_tuple)` - Loss might be tuple, extract properly
37+ ❌ `batch[:, :-1]` on 1D arrays - Check array dimensions first
38+ ❌ Assuming tensor shapes without verification
39+
40+ **CRITICAL MLX VALUE AND SHAPE HANDLING:**
41+
42+ 🚨 **Loss Value Extraction:**
43+ ```python
44+ # WRONG: float(loss_value) when loss_value might be tuple
45+ # CORRECT: Handle MLX loss properly
46+ if isinstance(loss_value, tuple):
47+ loss_scalar = float(loss_value[0]) # Extract first element
48+ elif isinstance(loss_value, mx.array):
49+ loss_scalar = float(mx.eval(loss_value)) # Evaluate and convert
50+ else:
51+ loss_scalar = float(loss_value)
52+ ```
53+
54+ 🚨 **Array Indexing Safety:**
55+ ```python
56+ # WRONG: batch[:, :-1] without checking dimensions
57+ # CORRECT: Check shape before indexing
58+ if batch.ndim >= 2:
59+ inputs = batch[:, :-1]
60+ targets = batch[:, 1:]
61+ else:
62+ # Handle 1D case or reshape
63+ inputs = batch[:-1]
64+ targets = batch[1:]
65+ ```
66+
67+ 🚨 **Model Parameter Updates:**
68+ ```python
69+ # WRONG: model.update_parameters(new_params)
70+ # CORRECT: Use optimizer.update()
71+ optimizer.update(model, grads)
72+ mx.eval(model.parameters(), optimizer.state)
73+ ```
3574 ❌ `mx.value_and_grad(fn, has_aux=True)` - has_aux parameter does NOT exist in MLX
3675 ❌ `mx.value_and_grad(fn, **kwargs)` - No keyword arguments supported except argnums/argnames
3776 ❌ Assuming `mx.eval()` always returns arrays - Can return None
0 commit comments