Open
Labels
benchmark (performance measurement or profiling), framework:jax (JAX-specific issue), moderate (moderate impact, fix when possible), performance (runtime performance improvement)
Description
Goal
Quantify the per-call tracing overhead in the JAX backend when called without jax.jit, and determine at what call count the amortized JIT cost breaks even.
Motivation
JAX re-traces Python functions on every call in eager mode. The JAX kabsch and horn functions include:
- A custom VJP (`safe_svd` with `@jax.custom_vjp`)
- Multiple `jnp` operations with shape-dependent branches (e.g. `is_single`, `batch_dims`)
- A Python-level conditional in the VJP backward (`vmap_diag = jax.vmap(...) if S.ndim > 1 else ...`)
None of the public functions carry @jax.jit. Users calling these in a training loop without explicit JIT will pay trace overhead on every step. The magnitude of this overhead is uncharacterized.
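The difference between eager execution and cached JIT tracing can be seen directly by counting how often the Python body of a function runs. A minimal sketch (the `kabsch_core` function here is a hypothetical stand-in for the library's `kabsch`, not its actual implementation):

```python
import jax
import jax.numpy as jnp

trace_count = 0

def kabsch_core(P, Q):
    # Hypothetical stand-in for the library's kabsch; counts how often
    # the Python body actually executes.
    global trace_count
    trace_count += 1
    H = P.T @ Q
    U, S, Vt = jnp.linalg.svd(H)
    return U @ Vt

P = jnp.ones((5, 3))
Q = jnp.ones((5, 3))

# Eager: the Python body (and per-op dispatch) runs on every call.
kabsch_core(P, Q)
kabsch_core(P, Q)
assert trace_count == 2

# JIT: the body is traced once per input shape/dtype, then the
# compiled executable is reused from the cache.
jitted = jax.jit(kabsch_core)
jitted(P, Q)
jitted(P, Q)
assert trace_count == 3
```

Note that Python-level conditionals such as `if S.ndim > 1` are resolved at trace time, so a shape change forces a fresh trace even under `jax.jit`.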
Experimental Design
For kabsch and horn, measure:
- Eager (no JIT): wall time per call for B=1 and B=64, repeated 1000 times (includes trace cost each time)
- JIT-compiled: `jit_kabsch = jax.jit(kabsch)`, wall time per call after the first (warmup) call
- JIT compile time: time for the first JIT-traced call (compilation cost)
Compute:
- Eager overhead per call vs. JIT-compiled per call (ratio)
- Break-even N: how many calls until JIT pays off relative to eager
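The measurement above could be harnessed roughly as follows. This is a sketch under assumptions: `kabsch` is a hypothetical batched stand-in, repeat counts are reduced for brevity, and `block_until_ready()` is used to force JAX's asynchronous dispatch to finish before stopping the clock:

```python
import time
import jax
import jax.numpy as jnp

def kabsch(P, Q):
    # Hypothetical stand-in for the library's kabsch (batched over the
    # leading dimension B).
    H = jnp.swapaxes(P, -2, -1) @ Q
    U, S, Vt = jnp.linalg.svd(H)
    return U @ Vt

def time_per_call(fn, args, repeats=50):
    # Median wall time per call; block_until_ready() ensures the
    # computation has actually completed, not just been dispatched.
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args).block_until_ready()
        times.append(time.perf_counter() - t0)
    return sorted(times)[len(times) // 2]

key = jax.random.PRNGKey(0)
P = jax.random.normal(key, (64, 10, 3))   # B=64 case
Q = jax.random.normal(key, (64, 10, 3))

jit_kabsch = jax.jit(kabsch)

# First JIT call: trace + compile cost.
t0 = time.perf_counter()
jit_kabsch(P, Q).block_until_ready()
t_compile = time.perf_counter() - t0

t_eager = time_per_call(kabsch, (P, Q))
t_jit = time_per_call(jit_kabsch, (P, Q))

# Break-even N: calls until the one-time compile cost is amortized
# by the per-call savings (guarded against a non-positive saving).
break_even = t_compile / max(t_eager - t_jit, 1e-12)
print(f"eager {t_eager*1e6:.1f}us  jit {t_jit*1e6:.1f}us  "
      f"break-even ~{break_even:.0f} calls")
```

The same harness would be run per (B, N) combination and for `horn` to fill in the deliverables table.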
Expected Deliverables
- Table: eager vs. JIT time per call for key (B, N) combinations
- Break-even call count for typical use cases
- Documentation update recommending `jax.jit` wrapping and showing the pattern
- Assessment of whether pre-jitted convenience exports (e.g. `kabsch_jit`) would be valuable
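The proposed convenience export would likely amount to a module-level `jax.jit` wrapper. A minimal sketch, again with a hypothetical `kabsch` standing in for the library's function:

```python
import jax
import jax.numpy as jnp

def kabsch(P, Q):
    # Hypothetical stand-in for the library's kabsch function.
    H = jnp.swapaxes(P, -2, -1) @ Q
    U, _, Vt = jnp.linalg.svd(H)
    return U @ Vt

# Module-level pre-jitted export: traced and compiled once per input
# shape/dtype, then reused across calls. This is the proposed
# `kabsch_jit` convenience name, not an existing export.
kabsch_jit = jax.jit(kabsch)

P = jnp.eye(3)
Q = jnp.eye(3)
R = kabsch_jit(P, Q)
assert R.shape == (3, 3)
```

One trade-off to note in the assessment: a module-level jitted export shares a single compilation cache across all users of the module, but also hides the warmup cost of the first call for each new input shape.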