Commit 6ec75d7

Update config.yaml
1 parent 93336c1 commit 6ec75d7

File tree: 1 file changed (+28 −7 lines)

examples/mlx_finetuning_optimization/config.yaml

Lines changed: 28 additions & 7 deletions
```diff
@@ -9,13 +9,13 @@ log_level: "INFO"
 # LLM configuration optimized for algorithmic pattern evolution
 llm:
   primary_model: "gemini-2.5-flash-preview-05-20"
-  primary_model_weight: 0.7
+  primary_model_weight: 0.6
   secondary_model: "gemini-2.5-pro-preview-05-06"
-  secondary_model_weight: 0.3
+  secondary_model_weight: 0.4
   api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
-  temperature: 0.8
+  temperature: 0.7
   top_p: 0.95
-  max_tokens: 24000
+  max_tokens: 32000
   timeout: 900  # Longer timeout for complex optimization reasoning

 # Specialized prompt for memory and algorithmic optimization with MLX API safety
```
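To make the weight change concrete, here is a minimal, hypothetical sketch of how a primary/secondary weight pair like 0.6/0.4 is typically consumed, assuming (as in OpenEvolve-style LLM ensembles) that the weights act as per-call sampling probabilities; the `pick_model` helper and this exact mechanism are illustrative, not code from this repo.

```python
# Hypothetical sketch: weighted model selection, assuming the config's
# *_model_weight values are used as sampling probabilities per LLM call.
import random

MODELS = {
    "gemini-2.5-flash-preview-05-20": 0.6,  # primary_model_weight
    "gemini-2.5-pro-preview-05-06": 0.4,    # secondary_model_weight
}

def pick_model(models: dict[str, float]) -> str:
    """Sample one model name with probability proportional to its weight."""
    names, weights = zip(*models.items())
    return random.choices(names, weights=weights, k=1)[0]

print(pick_model(MODELS))
```

Under that reading, the commit shifts roughly 10% of generation calls from the flash model to the stronger pro model, trading speed for reasoning depth.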
````diff
@@ -280,18 +280,39 @@ prompt:
   ```python
   # WRONG: Using JAX-style has_aux parameter
   (scaled_loss_val, unscaled_loss_val), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)
+  # This causes unscaled_loss_val to be a tuple! float(tuple) fails!
+
+  # WRONG: Multiple return values from loss function when using value_and_grad
+  def loss_fn(model):
+      logits = model(inputs)
+      loss = nn.losses.cross_entropy(logits, targets)
+      return loss, some_aux_data  # WRONG! Creates tuple!
+
+  loss_tuple, grads = mx.value_and_grad(loss_fn)(model)  # loss_tuple is (loss, aux_data)
+  return float(loss_tuple)  # ERROR: float() argument must be a real number, not 'tuple'

   # RIGHT: MLX only supports simple value_and_grad
+  def loss_fn(model):
+      logits = model(inputs)
+      loss = nn.losses.cross_entropy(logits, targets)
+      return loss  # Return ONLY the loss, not a tuple
+
   loss_value, grads = mx.value_and_grad(loss_fn)(model)
+  return float(loss_value), should_update  # loss_value is now a scalar

-  # If you need scaled loss, handle it in the loss function itself:
+  # RIGHT: If you need auxiliary data, compute it separately
   def loss_fn(model):
       logits = model(inputs)
       loss = nn.losses.cross_entropy(logits, targets)
-      # Scale inside the function if needed
-      return loss / max(total_accumulation_steps, 1)
+      return loss  # Only return loss for value_and_grad

   loss_value, grads = mx.value_and_grad(loss_fn)(model)
+  # Compute auxiliary data separately if needed; outside a grad transform
+  # MLX does not build a gradient graph, so no no_grad context is needed
+  logits = model(inputs)
+  accuracy = compute_accuracy(logits, targets)
+
+  return float(loss_value), should_update
   ```

   ❌ **'NoneType' object is not subscriptable**
````
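As a sanity check on the value_and_grad guidance above, here is a small runnable sketch of the "return only the loss" pattern; the toy `nn.Linear` model, the dummy batch, and the `reduction="mean"` argument (which collapses the per-example losses to a scalar so `float()` succeeds) are my assumptions, not code from the commit.

```python
# Minimal runnable sketch of the pattern above; toy model and data are
# illustrative only.
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 3)                       # toy 4-feature, 3-class classifier
inputs = mx.random.normal((8, 4))             # dummy batch
targets = mx.array([0, 1, 2, 0, 1, 2, 0, 1])  # dummy labels

def loss_fn(model, inputs, targets):
    logits = model(inputs)
    # Return ONLY a scalar loss -- no (loss, aux) tuple
    return nn.losses.cross_entropy(logits, targets, reduction="mean")

# nn.value_and_grad is the mlx.nn convenience wrapper that differentiates
# loss_fn with respect to the model's trainable parameters
loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, inputs, targets)
print(float(loss))  # loss is a scalar array, so float() succeeds
```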
