
profile: JAX call overhead with and without jax.jit wrapping #25

@hunter-heidenreich


Goal

Quantify the per-call overhead of the JAX backend when its functions are called without jax.jit, and determine the call count at which the one-time JIT compilation cost is amortized (the break-even point).

Motivation

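Without jax.jit, JAX executes eagerly: every call re-runs the Python function body and dispatches each jnp operation individually, so this Python and dispatch cost is paid again on every call. The JAX kabsch and horn functions include: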

  • A custom VJP (safe_svd with @jax.custom_vjp)
  • Multiple jnp operations with shape-dependent branches (e.g. is_single, batch_dims)
  • A Python-level conditional in the VJP backward pass (vmap_diag = jax.vmap(...) if S.ndim > 1 else ...)

None of the public functions is decorated with @jax.jit. Users who call them in a training loop that is not itself wrapped in jax.jit will pay this overhead on every step. The magnitude of this overhead has not been measured.
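The difference between eager and jitted execution can be made visible with a Python side effect: the counter below increments every time the Python body runs. This is a minimal sketch; centered_cross_covariance is a hypothetical stand-in for the first step of a Kabsch alignment, not the library's code.

```python
import jax
import jax.numpy as jnp

# Counts how many times the Python body executes (eagerly, or while jit traces it).
python_body_runs = 0


def centered_cross_covariance(P, Q):
    """Hypothetical stand-in for the first step of kabsch: H = Pc^T Qc."""
    global python_body_runs
    python_body_runs += 1  # Python side effect, not part of the computation
    Pc = P - P.mean(axis=-2, keepdims=True)
    Qc = Q - Q.mean(axis=-2, keepdims=True)
    return jnp.swapaxes(Pc, -1, -2) @ Qc


P = jnp.ones((10, 3))
Q = jnp.ones((10, 3))

# Eager: the Python body (and every jnp dispatch inside it) runs on each call.
for _ in range(5):
    centered_cross_covariance(P, Q)
eager_runs = python_body_runs

# jit: the body runs once while tracing; later calls with the same shapes and
# dtypes reuse the cached compiled executable without re-running Python.
python_body_runs = 0
jit_cov = jax.jit(centered_cross_covariance)
for _ in range(5):
    jit_cov(P, Q)
jit_runs = python_body_runs

print(eager_runs, jit_runs)  # 5 1
```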

Experimental Design

For kabsch and horn, measure:

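  1. Eager (no JIT): wall time per call for B=1 and B=64, averaged over 1000 repeated calls; every call pays the Python and dispatch overhead again
  2. JIT-compiled: jit_kabsch = jax.jit(kabsch), wall time per call after the first (warmup) call
  3. JIT compile time: wall time of the first call to jit_kabsch, which covers tracing, XLA compilation, and one execution

All timings should block on the outputs (e.g. with .block_until_ready()); JAX dispatches work asynchronously, so an unblocked timer can stop before the computation finishes.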
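The three measurements could be collected with a harness along these lines. This is a minimal sketch: kabsch_standin is a hypothetical SVD-based Kabsch written only so the script runs on its own (it lacks the library's safe_svd custom VJP), and N_POINTS = 100 is an assumption because the design does not fix the point count. To profile the real code, substitute the library's kabsch and horn.

```python
import time

import jax
import jax.numpy as jnp


def kabsch_standin(P, Q):
    """Hypothetical Kabsch stand-in (not the library's kabsch).

    P, Q: (..., N, 3). Returns R: (..., 3, 3) with centered Q ~= centered P @ R^T.
    """
    Pc = P - P.mean(axis=-2, keepdims=True)
    Qc = Q - Q.mean(axis=-2, keepdims=True)
    H = jnp.swapaxes(Pc, -1, -2) @ Qc  # cross-covariance, (..., 3, 3)
    U, _, Vt = jnp.linalg.svd(H, full_matrices=False)
    V, Ut = jnp.swapaxes(Vt, -1, -2), jnp.swapaxes(U, -1, -2)
    d = jnp.sign(jnp.linalg.det(V @ Ut))  # -1 flags a reflection
    D = jnp.ones(d.shape + (3,), dtype=H.dtype).at[..., 2].set(d)
    return V @ (D[..., :, None] * Ut)  # R = V diag(1, 1, d) U^T


def mean_time_per_call(fn, *args, n_calls=1000):
    """Mean wall time per call in seconds; blocks on every result because
    JAX dispatch is asynchronous."""
    start = time.perf_counter()
    for _ in range(n_calls):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_calls


def first_call_time(fn, *args):
    """Wall time of one call; for a fresh jax.jit wrapper this is
    trace + compile + one execution."""
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - start


N_POINTS = 100  # assumed point-set size
results = {}
key_p, key_q = jax.random.split(jax.random.PRNGKey(0))
for B in (1, 64):
    P = jax.random.normal(key_p, (B, N_POINTS, 3))
    Q = jax.random.normal(key_q, (B, N_POINTS, 3))

    kabsch_standin(P, Q).block_until_ready()  # eager warmup: populate per-op caches
    t_eager = mean_time_per_call(kabsch_standin, P, Q)

    jit_kabsch = jax.jit(kabsch_standin)
    t_first = first_call_time(jit_kabsch, P, Q)
    t_jit = mean_time_per_call(jit_kabsch, P, Q)

    results[B] = {"eager": t_eager, "jit": t_jit, "first_jit_call": t_first}
    print(f"B={B}: eager {t_eager * 1e6:.1f} us, jit {t_jit * 1e6:.1f} us, "
          f"first jit call {t_first * 1e3:.1f} ms, ratio {t_eager / t_jit:.1f}x")
```

The eager warmup call keeps the one-time per-operation compilation out of the eager average, so t_eager reflects steady-state dispatch cost rather than first-use cost.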

Compute:

  • Speedup: eager time per call divided by JIT-compiled time per call
  • Break-even call count: the number of calls after which the JIT path (first call plus compiled calls) costs less in total than calling eagerly
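Under a simple cost model, N eager calls cost N * t_eager, while the JIT path costs t_first for its first call (tracing, compilation, and one execution) plus t_jit for each of the remaining N - 1 calls. JIT breaks even once t_first + (N - 1) * t_jit <= N * t_eager, i.e. N >= (t_first - t_jit) / (t_eager - t_jit). A sketch of both computed quantities, assuming that model (it ignores the eager path's own first-call warmup); the timings in the example are illustrative, not measured.

```python
import math


def speedup(t_eager, t_jit):
    """Eager time per call divided by JIT-compiled time per call."""
    return t_eager / t_jit


def break_even_calls(t_first, t_jit, t_eager):
    """Smallest N with t_first + (N - 1) * t_jit <= N * t_eager.

    t_first: first call to a fresh jax.jit wrapper (trace + compile + run).
    Returns math.inf when JIT is never cheaper (t_eager <= t_jit).
    """
    if t_eager <= t_jit:
        return math.inf
    return max(1, math.ceil((t_first - t_jit) / (t_eager - t_jit)))


# Illustrative timings in seconds: 300 ms first call, 20 us compiled, 400 us eager.
print(break_even_calls(0.3, 20e-6, 400e-6))  # 790
```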

Expected Deliverables

  • Table: eager vs. JIT time per call for key (B, N) combinations
  • Break-even call count for typical use cases
  • Documentation update recommending jax.jit wrapping and showing the pattern
  • Assessment of whether pre-jitted convenience exports (e.g. kabsch_jit) would be valuable
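For the documentation update, a minimal sketch of two wrapping patterns follows. It assumes a stand-in kabsch taking inputs shaped (..., N, 3) (an assumption; the real function also routes through safe_svd with its custom VJP). jit_kabsch is the name from the experimental design; alignment_loss and train_step are hypothetical names for illustration.

```python
import math

import jax
import jax.numpy as jnp


def kabsch(P, Q):
    """Hypothetical stand-in for the library's kabsch (plain SVD, no custom VJP).
    P, Q: (..., N, 3) -> R: (..., 3, 3) with centered Q ~= centered P @ R^T."""
    Pc = P - P.mean(axis=-2, keepdims=True)
    Qc = Q - Q.mean(axis=-2, keepdims=True)
    U, _, Vt = jnp.linalg.svd(jnp.swapaxes(Pc, -1, -2) @ Qc, full_matrices=False)
    V, Ut = jnp.swapaxes(Vt, -1, -2), jnp.swapaxes(U, -1, -2)
    d = jnp.sign(jnp.linalg.det(V @ Ut))  # guard against reflections
    D = jnp.ones(d.shape + (3,), dtype=P.dtype).at[..., 2].set(d)
    return V @ (D[..., :, None] * Ut)


# Pattern 1: wrap once, outside any loop, and reuse the wrapper.
# A pre-jitted export such as the hypothetical kabsch_jit would amount to this line.
jit_kabsch = jax.jit(kabsch)


# Pattern 2: jit the whole training step; kabsch is traced inline as part of it.
def alignment_loss(P, Q):
    R = kabsch(P, Q)
    Pc = P - P.mean(axis=-2, keepdims=True)
    Qc = Q - Q.mean(axis=-2, keepdims=True)
    return jnp.mean((Pc @ jnp.swapaxes(R, -1, -2) - Qc) ** 2)


train_step = jax.jit(jax.value_and_grad(alignment_loss))

# Usage: recover a known rotation about the z-axis.
theta = 0.3
R_true = jnp.array([[math.cos(theta), -math.sin(theta), 0.0],
                    [math.sin(theta), math.cos(theta), 0.0],
                    [0.0, 0.0, 1.0]])
P = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
Q = P @ R_true.T + 1.0  # rotate, then translate
R = jit_kabsch(P, Q)
loss, grad_P = train_step(P, Q)
```

Pattern 2 bears on the kabsch_jit assessment: users who already jit their training step gain nothing extra from a pre-jitted export, because the inner function is traced once as part of the outer computation.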

Metadata

Assignees: No one assigned
Labels: benchmark (Performance measurement or profiling), framework:jax (JAX-specific issue), moderate (Moderate impact, fix when possible), performance (Runtime performance improvement)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests