
Commit 979b4a7

Update evaluator.py
1 parent 50bee94 commit 979b4a7

File tree: 1 file changed (+14, -13 lines)


examples/mlx_kernel_optimization/evaluator.py

Lines changed: 14 additions & 13 deletions
@@ -306,16 +306,9 @@ def loss_fn(model, inputs, targets):
         logits_flat = logits.reshape(-1, vocab_size)
         targets_flat = targets.reshape(-1)
 
-        # Mask padding tokens (assume 0 is pad token)
-        mask = targets_flat != 0
-        if mx.sum(mask) == 0:  # All padding, use all tokens
-            mask = mx.ones_like(targets_flat, dtype=mx.bool_())
-
-        # Apply mask
-        logits_masked = logits_flat[mask]
-        targets_masked = targets_flat[mask]
-
-        return nn.losses.cross_entropy(logits_masked, targets_masked, reduction='mean')
+        # Simple cross-entropy without masking to avoid boolean indexing issues
+        # MLX doesn't support boolean indexing, so we'll compute loss on all tokens
+        return nn.losses.cross_entropy(logits_flat, targets_flat, reduction='mean')
 
     # Gradient function
     value_and_grad_fn = mx.value_and_grad(loss_fn)
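
Note: the removed code relied on boolean fancy indexing (`logits_flat[mask]`), which the new comment says MLX does not support. If padding still needs to be excluded from the loss, one alternative is to weight per-token losses instead of gathering them. The sketch below is not part of this commit; `masked_loss_fn` is a hypothetical helper, and it assumes pad token id 0 (as in the removed code) and that the vocabulary size can be read off the logits.

# Hypothetical alternative, not in this commit: mask padding by weighting
# per-token losses rather than boolean-indexing into the arrays.
import mlx.core as mx
import mlx.nn as nn

def masked_loss_fn(model, inputs, targets):
    logits = model(inputs)
    logits_flat = logits.reshape(-1, logits.shape[-1])
    targets_flat = targets.reshape(-1)

    # Per-token losses, then a mean over non-pad positions only
    per_token = nn.losses.cross_entropy(logits_flat, targets_flat, reduction='none')
    mask = (targets_flat != 0).astype(per_token.dtype)  # assume 0 is the pad id
    denom = mx.maximum(mask.sum(), 1)  # all-padding batch: avoid divide-by-zero
    return (per_token * mask).sum() / denom
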
@@ -324,9 +317,17 @@ def loss_fn(model, inputs, targets):
     def get_memory_usage():
         # Simple memory estimation based on array sizes
         total_memory = 0
-        for param in model.parameters():
-            if hasattr(param, 'size'):
-                total_memory += param.size * 4  # Assume 4 bytes per float
+        try:
+            for param in model.parameters():
+                if hasattr(param, 'shape'):
+                    # Calculate memory usage: shape -> total elements -> bytes
+                    total_elements = 1
+                    for dim in param.shape:
+                        total_elements *= dim
+                    total_memory += total_elements * 4  # Assume 4 bytes per float32
+        except Exception:
+            # Fallback to simple estimation
+            total_memory = 64 * 1024 * 1024  # 64MB default
         return total_memory / (1024 * 1024)  # MB
 
     initial_memory = get_memory_usage()
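
Note: in MLX, `model.parameters()` returns a nested dict of arrays rather than a flat list of arrays, which is likely why the plain loop needed a `try`/`except`. A sketch of the same estimate using MLX's tree utilities follows; it is not part of the commit, `get_memory_usage_mb` is a hypothetical helper name, and it assumes the `nbytes` attribute available on MLX arrays instead of the 4-bytes-per-element approximation.

# Hypothetical alternative, not in this commit: flatten the nested parameter
# dict with tree_flatten and sum exact byte counts per array.
from mlx.utils import tree_flatten

def get_memory_usage_mb(model):
    params = tree_flatten(model.parameters())  # list of (name, array) pairs
    total_bytes = sum(p.nbytes for _, p in params)
    return total_bytes / (1024 * 1024)  # MB
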
