
[BUG] compile(shapeless=True): Reduce returns stale values on dynamically-shaped inputs from take/gather #3201

@ghstrider

Description

Summary

When a compiled function with shapeless=True performs a reduction (sum, mean) on the output of mx.take (GatherAxis), the reduction returns stale results from the first call. The take primitive correctly produces dynamic-sized output, but Reduce/ReduceAll replays the cached first-call result instead of re-executing.

Related: ml-explore/mlx-c#104 — same bug reproduced via the C API, confirming it's in mlx core.

MLX Version

0.31.0

Minimal Reproducer

import mlx.core as mx

def _fn(buf, idx):
    taken = mx.take(buf, idx, axis=0)
    return taken.sum()

fn = mx.compile(_fn, shapeless=True)

buf = mx.array([10.0, 20.0, 30.0, 40.0, 50.0])
mx.eval(buf)

for n in [1, 2, 3, 4]:
    idx = mx.arange(n)
    s = fn(buf, idx)
    mx.eval(s)
    expected = sum(buf.tolist()[:n])
    print(f"n={n}: sum={s.item():.0f} expected={expected:.0f} "
          f"{'PASS' if abs(s.item() - expected) < 0.01 else 'FAIL'}")

Output

n=1: sum=10 expected=10 PASS
n=2: sum=10 expected=30 FAIL    <-- should be 30
n=3: sum=10 expected=60 FAIL    <-- should be 60
n=4: sum=10 expected=100 FAIL   <-- should be 100

Diagnostic Results

| Variant | Result | Notes |
|---|---|---|
| take only (no reduction) | PASS | GatherAxis correctly returns dynamic-sized output |
| take → matmul | PASS | Matmul on dynamic-sized take output works |
| take → sum | FAIL | Sum always returns first-call value (10) |
| take → mean | FAIL | Returns 10/n: sum part stale, count uses new size |
| take → reshape(-1) → sum | FAIL | Reshape doesn't help |
| take → sum with shapeless=False | PASS | Re-traces on shape change |
| take → sum without compile | PASS | No compilation, works correctly |

Root Cause Hypothesis

The mean result (10/n) is the smoking gun: MLX implements mean as sum / count. The count correctly uses the new input size, but the sum returns the stale traced value.
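The stale-sum/fresh-count pattern can be checked with plain arithmetic (no MLX needed). This sketch computes what the bug would produce under that hypothesis versus the correct means, using the reproducer's buffer values:

```python
# Pure-Python check of the "smoking gun" arithmetic: if mean were
# computed as stale_sum / fresh_count, the observed values would be
# 10/n, while the correct mean of the first n elements differs.
buf = [10.0, 20.0, 30.0, 40.0, 50.0]
stale_sum = sum(buf[:1])  # result cached from the first call (n=1) -> 10.0

for n in [1, 2, 3, 4]:
    observed_mean = stale_sum / n     # what the bug produces
    correct_mean = sum(buf[:n]) / n   # what it should produce
    print(f"n={n}: observed={observed_mean:.2f} correct={correct_mean:.2f}")
```

The two columns agree only at n=1 (both 10.0), which matches the observed PASS on the first call and FAIL on every subsequent one.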

In compile_replace() with shapeless=True, Primitive::output_shapes() is called to infer output shapes for replayed nodes. For Reduce with reduce-all, the output is always scalar [], so the shape doesn't change between calls. The reduction kernel appears to either:

  1. Not be re-dispatched when input shapes change (since output shape is constant), or
  2. Use a cached intermediate buffer sized for the first-call input

Since the output shape is always [] regardless of input size, the shapeless machinery may be treating the output as unchanged and reusing the previous buffer without re-executing the kernel.
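To make the hypothesis concrete, here is a minimal sketch of why an output-shape comparison cannot detect the change. This is an assumption about the mechanism, not MLX's actual code; `reduce_all_output_shape` is a hypothetical stand-in for the shape inference of a reduce-all node:

```python
# Sketch (assumed mechanism): a reduce-all node's inferred output shape
# is [] for every input shape, so comparing output shapes between calls
# cannot tell that the input grew and the kernel must be re-dispatched.
def reduce_all_output_shape(input_shape):
    """Output shape of a full reduction: scalar [] regardless of input."""
    return []

first_call = reduce_all_output_shape([1])  # n=1 on the first call
later_call = reduce_all_output_shape([4])  # n=4 on a later call
assert first_call == later_call == []      # shape check sees "no change"
```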

Impact

This blocks using shapeless=True for compiled functions that include reductions over variable-length inputs — a common pattern in LLM inference (e.g., attention over growing KV cache sequences). We discovered this while trying to use mlx_compile via the C API to fuse transformer decode sub-blocks for performance.
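Until this is fixed, one shape-static workaround is to pad the gathered values to a fixed maximum length and zero out the padding with a mask, so every call sees identical shapes and no shapeless re-inference of the reduction is needed. The sketch below illustrates the idea in plain Python rather than MLX APIs; `MAX_LEN` and `masked_sum` are hypothetical names:

```python
# Hedged workaround sketch: pad to a fixed MAX_LEN and mask, keeping
# shapes static so a compiled reduction never needs dynamic re-inference.
MAX_LEN = 8  # assumed fixed upper bound on the index length

def masked_sum(buf, n, max_len=MAX_LEN):
    # take the first n elements, pad to max_len, mask out padding, reduce
    padded = buf[:n] + [0.0] * (max_len - n)
    mask = [1.0] * n + [0.0] * (max_len - n)
    return sum(p * m for p, m in zip(padded, mask))

buf = [10.0, 20.0, 30.0, 40.0, 50.0]
for n in [1, 2, 3, 4]:
    assert masked_sum(buf, n) == sum(buf[:n])
```

The same masking trick is commonly used for attention over a growing KV cache, where the cache is allocated at its maximum length up front.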
