diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index ce3d1cd4a9..5d49eee1ba 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -88,8 +88,8 @@ class ComputeCorrelograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): - params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) + def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: bool = False): + params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode) return params @@ -215,6 +215,7 @@ def compute_correlograms( window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", + fast_mode: bool = False, ): """ Compute correlograms using Numba or Numpy. @@ -225,11 +226,11 @@ def compute_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode ) else: return _compute_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode ) @@ -299,7 +300,7 @@ def _compute_num_bins(window_size, bin_size): return num_bins, num_half_bins -def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): +def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False): """ Computes cross-correlograms from multiple units. @@ -318,6 +319,9 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): method : str To use "numpy" or "numba". "auto" will use numba if available, otherwise numpy. + fast_mode : bool + If True, use faster implementations (currently only if method is 'numba'), + at the cost of possible minor numerical differences. Returns ------- @@ -336,15 +340,15 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) if method == "numpy": - correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size) + correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size, fast_mode=fast_mode) if method == "numba": - correlograms = _compute_correlograms_numba(sorting, window_size, bin_size) + correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode) return correlograms, bins # LOW-LEVEL IMPLEMENTATIONS -def _compute_correlograms_numpy(sorting, window_size, bin_size): +def _compute_correlograms_numpy(sorting, window_size, bin_size, fast_mode): """ Computes correlograms for all units in a sorting object. @@ -371,14 +375,14 @@ def _compute_correlograms_numpy(sorting, window_size, bin_size): spike_times = spikes[seg_index]["sample_index"] spike_unit_indices = spikes[seg_index]["unit_index"] - c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size) + c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size, fast_mode) correlograms += c0 return correlograms -def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size): +def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size, fast_mode=False): """ A very well optimized algorithm for the cross-correlation of spike trains, copied from the Phy package, written by Cyrille Rossant. @@ -395,6 +399,9 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi The window size over which to perform the cross-correlation, in samples bin_size : int The size of which to bin lags, in samples. + fast_mode : bool, default: False + If True, use faster implementations (currently only if method is 'numba'), + at the cost of possible minor numerical differences. Returns ------- @@ -483,7 +490,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi return correlograms -def _compute_correlograms_numba(sorting, window_size, bin_size): +def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode): """ Computes cross-correlograms between all units in `sorting`. @@ -499,6 +506,9 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): The window size over which to perform the cross-correlation, in samples bin_size : int The size of which to bin lags, in samples. + fast_mode : bool + If True, use faster implementations (currently only if method is 'numba'), + at the cost of possible minor numerical differences. Returns ------- @@ -516,6 +526,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): spikes = sorting.to_spike_vector(concatenated=False) correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64) + if fast_mode: + num_threads = mp.cpu_count() + else: + num_threads = 1 + for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"] spike_unit_indices = spikes[seg_index]["unit_index"] @@ -527,6 +542,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): window_size, bin_size, num_half_bins, + num_threads, ) return correlograms @@ -539,9 +555,10 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): nopython=True, nogil=True, cache=False, + parallel=True, ) def _compute_correlograms_one_segment_numba( - correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins + correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins, num_threads ): """ Compute the correlograms using `numba` for speed. @@ -570,9 +587,12 @@ def _compute_correlograms_one_segment_numba( The window size over which to perform the cross-correlation, in samples bin_size : int The size of which to bin lags, in samples. + num_threads : int + The number of threads to use in parallel. """ + numba.set_num_threads(num_threads) start_j = 0 - for i in range(spike_times.size): + for i in numba.prange(spike_times.size): for j in range(start_j, spike_times.size): if i == j: continue diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 523aa4ba05..9c3a543e59 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -104,6 +104,28 @@ def test_equal_results_correlograms(window_and_bin_ms): assert np.array_equal(result_numpy, result_numba) +@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") +@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)]) +def test_equal_results_fast_correlograms(window_and_bin_ms): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. + """ + + window_ms, bin_ms = window_and_bin_ms + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) + + result_numba_fast, bins_numpy = _compute_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True + ) + result_numba, bins_numba = _compute_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False + ) + from numpy.testing import assert_almost_equal + + assert_almost_equal(result_numba_fast, result_numba) + + @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) def test_flat_cross_correlogram(method): """