Skip to content

feat: Add JAX acceleration support to z_n_search#966

Open
temp-noob wants to merge 1 commit intoStingraySoftware:mainfrom
temp-noob:useJaxFor_z_n_search
Open

feat: Add JAX acceleration support to z_n_search#966
temp-noob wants to merge 1 commit intoStingraySoftware:mainfrom
temp-noob:useJaxFor_z_n_search

Conversation

@temp-noob
Copy link

feat: Add JAX acceleration support to z_n_search

Adds optional JAX-accelerated backend for z_n_search via use_jax parameter.
The JAX implementation computes exact unbinned Z^2_n statistics directly from
event phases, complementing the existing numba-JIT'd binned approach.

Implementation details:

  • JAX kernel using vmap for vectorized frequency/fdot grid search
  • Conditional @jax.jit with static_argnums for nharm parameter
  • Graceful fallback when JAX not installed (HAS_JAX flag)
  • Compatible with existing API (1D and 2D search with fdots)

Added comprehensive test suite (16 tests):

  • Correctness tests validating peak detection
  • Comparison tests (JAX vs standard: >0.95 correlation)
  • Benchmark tests showing ~19x speedup on 1D searches (CPU)

Performance on CPU: 1D (no fdot vector) searches gain ~19x speedup. Larger gains expected
on GPU-enabled JAX backends.

@temp-noob
Copy link
Author

@matteobachetti, this is my first time contributing to the repo. Let me know if I can help out with more things or things which would be more crucial.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant