Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class ComputeCorrelograms(AnalyzerExtension):
use_nodepipeline = False
need_job_kwargs = False

def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method)
def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: bool = False):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode)

return params

Expand Down Expand Up @@ -215,6 +215,7 @@ def compute_correlograms(
window_ms: float = 50.0,
bin_ms: float = 1.0,
method: str = "auto",
fast_mode: bool = False,
):
"""
Compute correlograms using Numba or Numpy.
Expand All @@ -225,11 +226,11 @@ def compute_correlograms(

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
return compute_correlograms_sorting_analyzer(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
)
else:
return _compute_correlograms_on_sorting(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
)


Expand Down Expand Up @@ -299,7 +300,7 @@ def _compute_num_bins(window_size, bin_size):
return num_bins, num_half_bins


def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False):
"""
Computes cross-correlograms from multiple units.

Expand All @@ -318,6 +319,9 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
method : str
To use "numpy" or "numba". "auto" will use numba if available,
otherwise numpy.
fast_mode : bool
If True, use faster implementations (currently only if method is 'numba'),
at the cost of possible minor numerical differences.

Returns
-------
Expand All @@ -336,15 +340,15 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms)

if method == "numpy":
correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size)
correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size, fast_mode=fast_mode)
if method == "numba":
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size)
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode)

return correlograms, bins


# LOW-LEVEL IMPLEMENTATIONS
def _compute_correlograms_numpy(sorting, window_size, bin_size):
def _compute_correlograms_numpy(sorting, window_size, bin_size, fast_mode):
"""
Computes correlograms for all units in a sorting object.

Expand All @@ -371,14 +375,14 @@ def _compute_correlograms_numpy(sorting, window_size, bin_size):
spike_times = spikes[seg_index]["sample_index"]
spike_unit_indices = spikes[seg_index]["unit_index"]

c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size)
c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size, fast_mode)

correlograms += c0

return correlograms


def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size):
def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size, fast_mode=False):
"""
A very well optimized algorithm for the cross-correlation of
spike trains, copied from the Phy package, written by Cyrille Rossant.
Expand All @@ -395,6 +399,9 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi
The window size over which to perform the cross-correlation, in samples
bin_size : int
The size of which to bin lags, in samples.
fast_mode : bool, default: False
If True, use faster implementations (currently only if method is 'numba'),
at the cost of possible minor numerical differences.

Returns
-------
Expand Down Expand Up @@ -483,7 +490,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi
return correlograms


def _compute_correlograms_numba(sorting, window_size, bin_size):
def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
"""
Computes cross-correlograms between all units in `sorting`.

Expand All @@ -499,6 +506,9 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
The window size over which to perform the cross-correlation, in samples
bin_size : int
The size of which to bin lags, in samples.
fast_mode : bool
If True, use faster implementations (currently only if method is 'numba'),
at the cost of possible minor numerical differences.

Returns
-------
Expand All @@ -516,6 +526,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
spikes = sorting.to_spike_vector(concatenated=False)
correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)

if fast_mode:
num_threads = mp.cpu_count()
else:
num_threads = 1

for seg_index in range(sorting.get_num_segments()):
spike_times = spikes[seg_index]["sample_index"]
spike_unit_indices = spikes[seg_index]["unit_index"]
Expand All @@ -527,6 +542,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
window_size,
bin_size,
num_half_bins,
num_threads,
)

return correlograms
Expand All @@ -539,9 +555,10 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
nopython=True,
nogil=True,
cache=False,
parallel=True,
)
def _compute_correlograms_one_segment_numba(
correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins
correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins, num_threads
):
"""
Compute the correlograms using `numba` for speed.
Expand Down Expand Up @@ -570,9 +587,12 @@ def _compute_correlograms_one_segment_numba(
The window size over which to perform the cross-correlation, in samples
bin_size : int
The size of which to bin lags, in samples.
num_threads : int
The number of threads to use in parallel.
"""
numba.set_num_threads(num_threads)
start_j = 0
for i in range(spike_times.size):
for i in numba.prange(spike_times.size):
for j in range(start_j, spike_times.size):
if i == j:
continue
Expand Down
22 changes: 22 additions & 0 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,28 @@ def test_equal_results_correlograms(window_and_bin_ms):
assert np.array_equal(result_numpy, result_numba)


@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)])
def test_equal_results_fast_correlograms(window_and_bin_ms):
"""
Test that the 2 methods have same results with some varied time bins
that are not tested in other tests.
"""

window_ms, bin_ms = window_and_bin_ms
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

result_numba_fast, bins_numpy = _compute_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True
)
result_numba, bins_numba = _compute_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False
)
from numpy.testing import assert_almost_equal

assert_almost_equal(result_numba_fast, result_numba)


@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
def test_flat_cross_correlogram(method):
"""
Expand Down