Feature Proposal: Numba parallel permutation kernel and JIT caching in crand.py #424
Description
Currently, every new Python session recompiles the Numba JIT functions in crand.py from scratch, which takes around 80 seconds on a typical machine. On top of that, the inner computation loop in compute_chunk processes all observations one at a time, even though the work across observations is completely independent and could run in parallel.
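For context, a minimal sketch of the sequential pattern described above (a hypothetical simplification, not the actual crand.py code): a running cursor into the flat CSR-style weights array makes each iteration depend on the previous one, which prevents parallel execution.

```python
import numpy as np

def row_sums_sequential(weights_flat, cardinalities):
    """Hypothetical simplification of the sequential accumulator pattern.

    cardinalities[i] is the number of neighbors of observation i; the
    flat CSR-style array stores all neighbor weights back to back.
    """
    n = len(cardinalities)
    out = np.empty(n)
    cursor = 0  # running offset: each iteration must wait on the last
    for i in range(n):
        k = cardinalities[i]
        out[i] = weights_flat[cursor:cursor + k].sum()
        cursor += k  # the sequential dependency that blocks parallelism
    return out
```

Because `cursor` is carried from one iteration to the next, the loop body for observation `i` cannot start until all earlier observations have been processed, even though the actual computation per observation is independent.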
This proposal adds cache=True to the existing @njit decorators (vec_permutations and compute_chunk), bringing the cold-start cost down to under 2 milliseconds after the first run. It also introduces a lightweight helper, _wloc_offsets, that pre-computes per-observation offsets into the flat CSR weights array with np.cumsum, removing the sequential accumulator that was blocking parallelism. A new compute_chunk_parallel function using @njit(parallel=True) and prange then distributes the per-observation work across all available CPU threads.
Benchmarked on a Ryzen 9 5900HX (16 threads), this reduces the inner kernel time from 88.8 ms to 19.3 ms, a 4.59x speedup. Full end-to-end wall time for Moran_Local at n=2500 with 199 permutations is a median of 28.7 ms; at n=10000 it is 127 ms.
The changes are confined entirely to crand.py and pyproject.toml. No public API is touched, no other files change, and the n_jobs > 1 joblib path is completely unmodified. A proof-of-concept implementation with tests is available at samay2504:feat/numba-parallel-crand.