nn.Linear operand order causes unnecessary permutes under vmap + grad #1195
Description
eqx.nn.Linear computes weight @ x. Under vmap + jax.grad, this places the output dimension before the batch dimensions in intermediate activations, forcing XLA to insert 3D permutes to restore batch-leading layout at every layer boundary.
Switching to x @ weight.T (the Flax / PyTorch approach) keeps batch dimensions leading throughout and eliminates these permutes entirely.
Reproduction
3D case — 8-block residual MLP, DIM=64, BATCH=256, SEQ=32, double vmap:
| Variant | Time | Speedup |
|---|---|---|
| `w @ x` (manual) | 66.9 ms | — |
| `x @ w.T` (manual) | 19.3 ms | 3.5x |
| `eqx.nn.Linear` | 67.7 ms | — |
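A minimal sketch of the structure being benchmarked (tiny sizes; the real run uses DIM=64, BATCH=256, SEQ=32 and measures timing, which this sketch does not):

```python
import jax
import jax.numpy as jnp

DIM, BATCH, SEQ = 8, 4, 3  # tiny illustrative sizes

def block_wx(w, v):
    # eqx.nn.Linear-style operand order
    return v + jax.nn.gelu(w @ v)

def block_xwt(w, v):
    # Flax/PyTorch-style operand order
    return v + jax.nn.gelu(v @ w.T)

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (DIM, DIM))
x = jax.random.normal(key, (BATCH, SEQ, DIM))

def make_loss(block):
    # Double vmap: outer over batch, inner over seq, weight unbatched.
    batched = jax.vmap(jax.vmap(block, in_axes=(None, 0)), in_axes=(None, 0))
    return lambda w: jnp.sum(batched(w, x) ** 2)

g_wx = jax.jit(jax.grad(make_loss(block_wx)))(w)
g_xwt = jax.jit(jax.grad(make_loss(block_xwt)))(w)
# Same gradients, different HLO layouts.
print(bool(jnp.allclose(g_wx, g_xwt, atol=1e-4)))
```

Both variants compute identical results; the difference is purely in the intermediate layouts XLA is forced to produce.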
Full example with HLO analysis here.
2D case — same architecture, DIM=128, BATCH=1024, single vmap (no seq dimension):
| Variant | Time | Speedup |
|---|---|---|
| `w @ x` (manual) | 8.61 ms | — |
| `x @ w.T` (manual) | 7.18 ms | 1.2x |
| `eqx.nn.Linear` | 8.29 ms | — |
Full example here. The effect is smaller here, but still nonzero.
Why this happens
For a single vector, w @ x and x @ w.T are both a matvec. But under vmap, the batch dimension ends up in different positions:
- `w @ x` → `(out, in) @ (batch, seq, in)` → `(out, batch, seq)` ← needs permute back
- `x @ w.T` → `(batch, seq, in) @ (in, out)` → `(batch, seq, out)` ← already correct
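The layout difference can be seen directly from the batched shapes (sizes here are illustrative, not the benchmark's):

```python
import jax.numpy as jnp

IN, OUT, BATCH, SEQ = 5, 7, 2, 3
w = jnp.ones((OUT, IN))
x = jnp.ones((BATCH, SEQ, IN))

# w @ x batched over (batch, seq): the output dim lands in front.
wx = jnp.einsum("oi,bsi->obs", w, x)
# x @ w.T: batch dims stay leading, no permute needed downstream.
xwt = x @ w.T

print(wx.shape)   # (7, 2, 3) -> (out, batch, seq)
print(xwt.shape)  # (2, 3, 7) -> (batch, seq, out)
```

Moving the output axis of `wx` to the back recovers `xwt` exactly, which is precisely the permute XLA has to insert.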
Downstream ops (gelu, layernorm, residual add, the next linear) expect batch-leading layout, so a permute is needed. In the 2D case, XLA can often fuse away the layout fix (or keep it implicit), so w @ x may not be meaningfully worse. In the 3D nested-vmap case, XLA commonly must materialize a large activation transpose (out,batch,seq) ↔ (batch,seq,out) at layer boundaries.
TL;DR: we want `nn.Linear` to compute `x @ w.T` rather than `w @ x`, because the latter relies on XLA to fix up the layout afterwards, and that fix-up becomes an expensive materialized transpose once `x` has three or more dimensions.
HLO evidence (3D case)
Transpose ops from jax.jit(jax.grad(fn)).lower(params).as_text():
```
── w @ x ──
32× transpose 64x256x32xf32 ↔ 256x32x64xf32   ← expensive 3D permutes
0 weight transposes, 32 activation permutes

── x @ w.T ──
48× transpose 64x64xf32 → 64x64xf32           ← 2D transposes only
48 weight transposes, 0 activation permutes

── eqx.nn.Linear ──
32× transpose 64x256x32xf32 ↔ 256x32x64xf32   ← identical to w @ x
0 weight transposes, 32 activation permutes
```
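For reference, a toy sketch of how such counts can be pulled from the lowered text (toy sizes; the assumed counting here is a plain substring count over the StableHLO dump, not the exact analysis script from the linked example):

```python
import jax
import jax.numpy as jnp

w = jnp.ones((4, 4))
x = jnp.ones((2, 3, 4))

def loss(w):
    # Double-vmapped w @ v, as in the 3D reproduction.
    f = jax.vmap(jax.vmap(lambda v: jax.nn.gelu(w @ v)))
    return jnp.sum(f(x))

# Lower grad through jit and inspect the StableHLO text.
hlo = jax.jit(jax.grad(loss)).lower(w).as_text()
print(hlo.count("transpose"))
```

Swapping the lambda body for `v @ w.T` and diffing the two dumps shows where the activation permutes appear.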
Suggested fix
In `_linear.py`, change:

```python
x = self.weight @ x
```

to:

```python
x = x @ self.weight.T
```

For unbatched (1D) input, both are equivalent. Under vmap + grad, batch dimensions stay leading and no activation permute is needed.
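A quick sanity check (a sketch, not part of the patch) that the two formulations agree for 1D input, using eqx.nn.Linear's `(out_features, in_features)` weight convention:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
weight = jax.random.normal(key, (7, 5))  # (out_features, in_features)
x = jax.random.normal(key, (5,))

y_old = weight @ x     # current operand order
y_new = x @ weight.T   # proposed operand order
print(bool(jnp.allclose(y_old, y_new)))
```

Since the unbatched semantics are unchanged, the swap only affects how vmap positions the batch dimensions in the lowered computation.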
Versions
- equinox 0.13.4
- jax 0.9.0
- CPU (no GPU)