nn.Linear operand order causes unnecessary permutes under vmap + grad #1195

@CatchemAL

Description

eqx.nn.Linear computes weight @ x. Under vmap + jax.grad, this places the output dimension before the batch dimensions in intermediate activations, forcing XLA to insert 3D permutes to restore batch-leading layout at every layer boundary.

Switching to x @ weight.T (the Flax / PyTorch approach) keeps batch dimensions leading throughout and eliminates these permutes entirely.
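The two orderings are mathematically equivalent, so the change is purely about layout. A minimal sketch of the difference (using raw `jax` rather than Equinox; shapes are illustrative and chosen to mirror the 3D reproduction below):

```python
import jax
import jax.numpy as jnp

w = jnp.arange(16.0).reshape(4, 4)   # (out_features, in_features)
x = jnp.ones((8, 3, 4))              # (batch, seq, in_features)

# Current Equinox order: weight @ x, double-vmapped over batch and seq.
wx = jax.vmap(jax.vmap(lambda v: w @ v))(x)

# Proposed order: x @ weight.T; batch axes never move internally.
xwt = jax.vmap(jax.vmap(lambda v: v @ w.T))(x)

assert wx.shape == xwt.shape == (8, 3, 4)
assert jnp.allclose(wx, xwt)  # same values; only the intermediate layout differs
```

The vmapped outputs are identical; the cost difference only shows up in the lowered HLO of a multi-layer network, where the `w @ v` form forces layout fixes between layers.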

Reproduction

3D case — 8-block residual MLP, DIM=64, BATCH=256, SEQ=32, double vmap:

| Variant | Time | Speedup |
| --- | --- | --- |
| `w @ x` (manual) | 66.9 ms | |
| `x @ w.T` (manual) | 19.3 ms | 3.5x |
| `eqx.nn.Linear` | 67.7 ms | |

Full example with HLO analysis here.

2D case — same architecture, DIM=128, BATCH=1024, single vmap (no seq dimension):

| Variant | Time | Speedup |
| --- | --- | --- |
| `w @ x` (manual) | 8.61 ms | |
| `x @ w.T` (manual) | 7.18 ms | 1.2x |
| `eqx.nn.Linear` | 8.29 ms | |

Full example here. Not big but not zero.
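A minimal sketch of the kind of harness that produces numbers like those above (block count, scaling, and timing method are illustrative, not the exact benchmark script):

```python
import time

import jax
import jax.numpy as jnp

DIM, BATCH, SEQ, BLOCKS = 64, 256, 32, 8

def make_params(key):
    # One (DIM, DIM) weight per residual block.
    return [jax.random.normal(k, (DIM, DIM)) / DIM**0.5
            for k in jax.random.split(key, BLOCKS)]

def mlp_wx(params, v):
    # Equinox-style operand order: weight @ x.
    for w in params:
        v = v + jax.nn.gelu(w @ v)
    return v

def mlp_xwt(params, v):
    # Proposed operand order: x @ weight.T.
    for w in params:
        v = v + jax.nn.gelu(v @ w.T)
    return v

def loss(f, params, x):
    # Double vmap over (batch, seq), as in the 3D reproduction.
    y = jax.vmap(jax.vmap(lambda v: f(params, v)))(x)
    return jnp.sum(y * y)

params = make_params(jax.random.key(0))
x = jax.random.normal(jax.random.key(1), (BATCH, SEQ, DIM))

for f in (mlp_wx, mlp_xwt):
    g = jax.jit(jax.grad(lambda p: loss(f, p, x)))
    g(params)[0].block_until_ready()          # compile + warm up
    t0 = time.perf_counter()
    for _ in range(5):
        g(params)[0].block_until_ready()
    print(f.__name__, (time.perf_counter() - t0) / 5, "s/iter")
```

Both variants compute the same gradients; only the wall-clock time differs.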

Why this happens

For a single vector, w @ x and x @ w.T are both a matvec. But under vmap, the batch dimension ends up in different positions:

```
w @ x    →  (out, in) @ (batch, seq, in) → (out, batch, seq)  ← needs permute back
x @ w.T  →  (batch, seq, in) @ (in, out) → (batch, seq, out)  ← already correct
```

Downstream ops (gelu, layernorm, residual add, the next linear) expect batch-leading layout, so a permute is needed. In the 2D case, XLA can often fuse away the layout fix (or keep it implicit), so w @ x may not be meaningfully worse. In the 3D nested-vmap case, XLA commonly must materialize a large activation transpose (out,batch,seq) ↔ (batch,seq,out) at layer boundaries.

TL;DR - we want `nn.Linear` to compute `x @ w.T` rather than the mathematically equivalent `(w @ x.T).T`, because the latter relies on XLA to eliminate the extra transposes, and that falls apart once `x` has three or more dimensions.

HLO evidence (3D case)

Transpose ops from `jax.jit(jax.grad(fn)).lower(params).as_text()`:

```
── w @ x ──
  32× transpose 64x256x32xf32 ↔ 256x32x64xf32   ← expensive 3D permutes
  0 weight transposes, 32 activation permutes

── x @ w.T ──
  48× transpose 64x64xf32 → 64x64xf32           ← 2D transposes only
  48 weight transposes, 0 activation permutes

── eqx.nn.Linear ──
  32× transpose 64x256x32xf32 ↔ 256x32x64xf32   ← identical to w @ x
  0 weight transposes, 32 activation permutes
```
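Transpose counts like these can be reproduced with a short script along the following lines (single-layer version for brevity; exact counts depend on the JAX/XLA version, so none are asserted here):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((256, 32, 64))          # (batch, seq, in)
w = jnp.ones((64, 64))               # (out, in)

def loss_wx(w):
    y = jax.vmap(jax.vmap(lambda v: w @ v))(x)
    return jnp.sum(y * y)

def loss_xwt(w):
    y = jax.vmap(jax.vmap(lambda v: v @ w.T))(x)
    return jnp.sum(y * y)

for fn in (loss_wx, loss_xwt):
    # Lower the jitted gradient and count transpose ops in the StableHLO text.
    hlo = jax.jit(jax.grad(fn)).lower(w).as_text()
    print(fn.__name__, "transpose ops:", hlo.count("transpose"))
```

For the full 8-block network the gap between the two variants is correspondingly larger.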

Suggested fix

In `_linear.py`, change:

```python
x = self.weight @ x
```

to:

```python
x = x @ self.weight.T
```

For unbatched input (1D), both are equivalent. Under vmap + grad, batch dimensions stay leading and no activation permute is needed.
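A quick sanity check of the 1D equivalence (shapes here are arbitrary, for illustration only):

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.key(0))
w = jax.random.normal(key_w, (8, 4))   # (out_features, in_features)
x = jax.random.normal(key_x, (4,))     # unbatched 1D input

# (w @ x)[i] = sum_j w[i, j] * x[j] = (x @ w.T)[i] — identical results.
assert jnp.allclose(w @ x, x @ w.T)
```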

Versions

  • equinox 0.13.4
  • jax 0.9.0
  • CPU (no GPU)
