Summary
When a compiled function with shapeless=True performs a reduction (sum, mean) on the output of mx.take (GatherAxis), the reduction returns stale results from the first call. The take primitive correctly produces dynamic-sized output, but Reduce/ReduceAll replays the cached first-call result instead of re-executing.
Related: ml-explore/mlx-c#104 — same bug reproduced via the C API, confirming it's in mlx core.
MLX Version
0.31.0
Minimal Reproducer
```python
import mlx.core as mx

def _fn(buf, idx):
    taken = mx.take(buf, idx, axis=0)
    return taken.sum()

fn = mx.compile(_fn, shapeless=True)

buf = mx.array([10.0, 20.0, 30.0, 40.0, 50.0])
mx.eval(buf)

for n in [1, 2, 3, 4]:
    idx = mx.arange(n)
    s = fn(buf, idx)
    mx.eval(s)
    expected = sum(buf.tolist()[:n])
    print(f"n={n}: sum={s.item():.0f} expected={expected:.0f} "
          f"{'PASS' if abs(s.item() - expected) < 0.01 else 'FAIL'}")
```

Output

```
n=1: sum=10 expected=10 PASS
n=2: sum=10 expected=30 FAIL <-- should be 30
n=3: sum=10 expected=60 FAIL <-- should be 60
n=4: sum=10 expected=100 FAIL <-- should be 100
```
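The failure pattern is consistent across n: every failing call returns the n=1 result. A quick pure-Python check of the expected prefix sums versus the observed outputs (no mlx required; values hardcoded from the reproducer above):

```python
# Buffer values from the reproducer above.
buf = [10.0, 20.0, 30.0, 40.0, 50.0]

# Correct prefix sums for each n.
expected = {n: sum(buf[:n]) for n in [1, 2, 3, 4]}
print(expected)  # {1: 10.0, 2: 30.0, 3: 60.0, 4: 100.0}

# Observed compiled output: every call repeats the first-call result.
observed = {n: expected[1] for n in [1, 2, 3, 4]}
print(observed)  # {1: 10.0, 2: 10.0, 3: 10.0, 4: 10.0}
```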
Diagnostic Results
| Variant | `shapeless=True` | Result |
|---|---|---|
| `take` only (no reduction) | PASS | GatherAxis correctly returns dynamic-sized output |
| `take` → `matmul` | PASS | Matmul on dynamic-sized `take` output works |
| `take` → `sum` | FAIL | Sum always returns first-call value (10) |
| `take` → `mean` | FAIL | Returns 10/n: sum part stale, count uses new size |
| `take` → `reshape(-1)` → `sum` | FAIL | Reshape doesn't help |
| `take` → `sum` with `shapeless=False` | PASS | Re-traces on shape change |
| `take` → `sum` without compile | PASS | No compilation, works correctly |
Root Cause Hypothesis
The mean result (10/n) is the smoking gun: MLX implements mean as sum / count. The count correctly uses the new input size, but the sum returns the stale traced value.
In compile_replace() with shapeless=True, Primitive::output_shapes() is called to infer output shapes for replayed nodes. For Reduce with reduce-all, the output is always scalar [], so the shape doesn't change between calls. The reduction kernel appears to either:
- Not be re-dispatched when input shapes change (since output shape is constant), or
- Use a cached intermediate buffer sized for the first-call input
Since the output shape is always [] regardless of input size, the shapeless machinery may be treating the output as unchanged and reusing the previous buffer without re-executing the kernel.
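If mean is lowered to sum / count with the sum stale at the first-call value but the count taken from the new input size, the observed means should be exactly 10/n. A quick arithmetic check against the correct prefix means (plain Python, values from the reproducer):

```python
buf = [10.0, 20.0, 30.0, 40.0, 50.0]
stale_sum = sum(buf[:1])  # first-call sum (10.0), replayed on every call

for n in [1, 2, 3, 4]:
    observed_mean = stale_sum / n     # stale sum, fresh count
    correct_mean = sum(buf[:n]) / n   # what mean should return
    print(f"n={n}: observed={observed_mean} correct={correct_mean}")
```

Only n=1 agrees, which matches the diagnostic table: the count node picks up the dynamic size while the Reduce node replays.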
Impact
This blocks using shapeless=True for compiled functions that include reductions over variable-length inputs — a common pattern in LLM inference (e.g., attention over growing KV cache sequences). We discovered this while trying to use mlx_compile via the C API to fuse transformer decode sub-blocks for performance.