@@ -32,6 +32,8 @@ prompt:
❌ `grads.astype()` when grads is a dict - Only works on mx.array
❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
❌ `mlx.utils.tree_*` functions - These don't exist
❌ `mx.value_and_grad(fn, has_aux=True)` - has_aux parameter does NOT exist in MLX
❌ `mx.value_and_grad(fn, **kwargs)` - No keyword arguments are supported other than argnums/argnames
❌ Assuming `mx.eval()` returns arrays - It evaluates in place and returns None
❌ Modulo operations without checking for zero divisors
❌ Assuming trainer attributes exist without checking (safe alternatives are sketched below)
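For the last few items, a minimal sketch of safe alternatives. The `cast_grads`/`should_update` helpers and the `accumulation_steps` attribute name are illustrative only; they are not part of MLX or the existing code:

```python
import mlx.core as mx

def cast_grads(grads, dtype=mx.float32):
    """Cast every array in a (possibly nested) grads dict; .astype() only exists on mx.array."""
    if isinstance(grads, mx.array):
        return grads.astype(dtype)
    if isinstance(grads, dict):
        return {k: cast_grads(v, dtype) for k, v in grads.items()}
    return grads  # leave non-array leaves untouched

def should_update(step, trainer):
    """Guard the modulo against a zero divisor and check attributes before use."""
    accumulation_steps = max(getattr(trainer, "accumulation_steps", 1), 1)
    return (step + 1) % accumulation_steps == 0
```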
@@ -68,6 +70,35 @@ prompt:
    return grads
```

✅ **Value and Grad Operations:**
```python
# CORRECT: Simple value_and_grad usage
loss_value, grads = mx.value_and_grad(loss_fn)(model)

# CORRECT: If you need multiple return values from loss_fn, handle them separately
def loss_fn(model):
    logits = model(inputs)
    # reduction="mean" keeps the loss a scalar, which value_and_grad requires
    loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
    # Return only the loss (not a tuple with aux data)
    return loss

loss_value, grads = mx.value_and_grad(loss_fn)(model)

# WRONG: mx.value_and_grad(loss_fn, has_aux=True)(model)  # has_aux not supported
# WRONG: (loss, aux), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)

# CORRECT: If you need auxiliary data, compute it separately after the gradient call
logits = model(inputs)  # Recompute for aux data
accuracy = compute_accuracy(logits, targets)
```
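When the gradient is taken with respect to an `nn.Module`, the training loops in mlx-examples typically use `nn.value_and_grad`, which differentiates with respect to `model.trainable_parameters()`. A minimal sketch reusing the `model`/`inputs`/`targets` names from the example above (the Adam hyperparameters are placeholders):

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

def loss_fn(model, inputs, targets):
    logits = model(inputs)
    return nn.losses.cross_entropy(logits, targets, reduction="mean")

# nn.value_and_grad returns a function computing (loss, grads w.r.t. trainable parameters)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss_value, grads = loss_and_grad_fn(model, inputs, targets)

optimizer = optim.Adam(learning_rate=1e-5)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)  # materialize the update
```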

✅ **Memory Management:**
```python
# Use mx.eval() to materialize computations
@@ -150,6 +181,8 @@ prompt:
8. ✓ Check object attributes exist before accessing
9. ✓ Handle None and empty arrays gracefully
10. ✓ Use safe fallbacks for all operations
11. ✓ mx.value_and_grad() used without has_aux parameter
12. ✓ Loss functions return single values, not tuples

**PRIMARY GOAL: Discover memory-efficient patterns that enable faster, lower-memory fine-tuning on Mac hardware**

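As a sanity check against that goal, a small helper in the spirit of the psutil-based measurement shown further below (the `rss_mb` name is illustrative only):

```python
import psutil

def rss_mb() -> float:
    """Resident memory of the current process in MB (same quantity as process.memory_info().rss below)."""
    return psutil.Process().memory_info().rss / 1024 / 1024

before = rss_mb()
# ... run one training step here ...
print(f"step delta: {rss_mb() - before:.1f} MB")
```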
@@ -204,6 +237,24 @@ prompt:
actual_memory = process.memory_info().rss / 1024 / 1024
```

❌ **value_and_grad() incompatible function arguments**
```python
# WRONG: Using JAX-style has_aux parameter
(scaled_loss_val, unscaled_loss_val), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)

# RIGHT: MLX only supports simple value_and_grad
loss_value, grads = mx.value_and_grad(loss_fn)(model)

# If you need a scaled loss, handle it in the loss function itself:
def loss_fn(model):
    logits = model(inputs)
    # reduction="mean" keeps the loss a scalar, which value_and_grad requires
    loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
    # Scale inside the function if needed
    return loss / max(total_accumulation_steps, 1)

loss_value, grads = mx.value_and_grad(loss_fn)(model)
```
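Building on that scaled-loss pattern, a sketch of gradient accumulation without any tree utilities. `add_trees` is a hypothetical helper, and in practice each step would feed a different micro-batch through `loss_fn`:

```python
def add_trees(a, b):
    """Sum two nested dicts/lists of mx.array, leaf by leaf."""
    if isinstance(a, dict):
        return {k: add_trees(a[k], b[k]) for k in a}
    if isinstance(a, (list, tuple)):
        return type(a)(add_trees(x, y) for x, y in zip(a, b))
    return a + b  # both leaves are mx.array

accumulated = None
for step in range(max(total_accumulation_steps, 1)):
    loss_value, grads = mx.value_and_grad(loss_fn)(model)
    accumulated = grads if accumulated is None else add_trees(accumulated, grads)
    mx.eval(loss_value, accumulated)  # materialize so the graph doesn't grow across steps
```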

❌ **'NoneType' object is not subscriptable**
```python
# WRONG: loss_value = mx.eval(loss)[0]  # mx.eval() returns None, so indexing it fails