diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py
index dca9711ccd..fe9dabb727 100644
--- a/src/spikeinterface/postprocessing/__init__.py
+++ b/src/spikeinterface/postprocessing/__init__.py
@@ -15,9 +15,12 @@
 from .correlograms import (
     ComputeACG3D,
     ComputeCorrelograms,
+    ComputeAutoCorrelograms,
     compute_acgs_3d,
     compute_correlograms,
+    compute_auto_correlograms,
     correlogram_for_one_segment,
+    auto_correlogram_for_one_segment,
 )

 from .isi import (
diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index ce3d1cd4a9..8d303296df 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -206,8 +206,121 @@ def _get_data(self):
         return self.data["ccgs"], self.data["bins"]


+class ComputeAutoCorrelograms(AnalyzerExtension):
+    """
+    Compute only the auto-correlograms of unit spike times.
+
+    Parameters
+    ----------
+    window_ms : float, default: 50.0
+        The window around the spike to compute the correlation in ms. For example,
+        if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms.
+    bin_ms : float, default: 1.0
+        The bin size in ms. This determines the bin size over which to
+        combine lags. For example, with a window size of -25 ms to 25 ms, and
+        bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
+    method : "auto" | "numpy" | "numba", default: "auto"
+        If "auto" and numba is installed, numba is used, otherwise numpy is used.
+
+    Returns
+    -------
+    correlograms : np.array
+        Auto-correlograms with shape (num_units, num_bins)
+    bins : np.array
+        The bin edges in ms
+
+    Notes
+    -----
+    In the extracellular electrophysiology context, a correlogram
+    is a visualisation of the results of a cross-correlation
+    between two spike trains; an auto-correlogram correlates a spike
+    train with itself. The cross-correlation slides one spike train
+    along the other sample-by-sample, taking the correlation at each 'lag'.
+    This results in a plot with 'lag' (i.e. time offset) on the x-axis and
+    'correlation' (i.e. how similar the two spike trains are) on the y-axis.
+    In this implementation, the y-axis result is the 'count' of spike matches
+    per time bin (rather than a computed correlation or covariance).
+
+    In the present implementation, a 'window' around spikes is first
+    specified. For example, if a window of 100 ms is taken, we will
+    take the correlation at lags from -50 ms to +50 ms around the spike peak.
+    In theory, we can have as many lags as we have samples. Often, this
+    visualisation is too high resolution and instead the lags are binned
+    (e.g. -50 to -45 ms, ..., -5 to 0 ms, 0 to 5 ms, ..., 45 to 50 ms).
+    When using counts as output, binning the lags involves adding up all
+    counts across a range of lags.
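+
+    Examples
+    --------
+    A minimal usage sketch (assuming a ``sorting_analyzer`` created beforehand
+    with ``create_sorting_analyzer``; not run as a doctest):
+
+    >>> ext = sorting_analyzer.compute("auto_correlograms", window_ms=50.0, bin_ms=1.0)
+    >>> acgs, bins = ext.get_data()  # acgs has shape (num_units, num_bins)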
+    """
+
+    extension_name = "auto_correlograms"
+    depend_on = []
+    need_recording = False
+    use_nodepipeline = False
+    need_job_kwargs = False
+
+    def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"):
+        params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method)
+        return params
+
+    def _select_extension_data(self, unit_ids):
+        # select the rows of the auto-correlogram array for the kept units
+        unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
+        new_acgs = self.data["acgs"][unit_indices]
+        new_bins = self.data["bins"]
+        new_data = dict(acgs=new_acgs, bins=new_bins)
+        return new_data
+
+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs
+    ):
+        """
+        Unlike cross-correlograms, the auto-correlogram of a merged unit cannot
+        be obtained from the previous auto-correlograms alone: if units i and j
+        are merged into unit k, then A_k = A_i + A_j + C_{i,j} + C_{j,i}, and
+        the cross-correlogram terms are not stored by this extension. The
+        auto-correlograms of merged units are therefore recomputed from the
+        new sorting, while those of unchanged units are copied over.
+        """
+
+        new_bins = self.data["bins"]
+        all_new_units = new_sorting_analyzer.unit_ids
+        arr = self.data["acgs"]
+
+        # compute the auto-correlograms of all merged units at once
+        new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids)
+        only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params)
+        new_acgs = np.zeros((len(all_new_units), only_new_acgs.shape[1]), dtype=np.int64)
+
+        for unit_ind, unit_id in enumerate(all_new_units):
+            if unit_id not in new_unit_ids:
+                keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
+                new_acgs[unit_ind, :] = arr[keep_unit_index, :]
+            else:
+                new_unit_index = new_sorting.id_to_index(unit_id)
+                new_acgs[unit_ind, :] = only_new_acgs[new_unit_index, :]
+
+        new_data = dict(acgs=new_acgs, bins=new_bins)
+        return new_data
+
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # TODO: for now we simply recompute the auto-correlograms of all units
+        new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
+        new_data = dict(acgs=new_acgs, bins=new_bins)
+        return new_data
+
+    def _run(self, verbose=False):
+        acgs, bins = _compute_auto_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
+        self.data["acgs"] = acgs
+        self.data["bins"] = bins
+
+    def _get_data(self):
+        return self.data["acgs"], self.data["bins"]
+
+
 register_result_extension(ComputeCorrelograms)
+register_result_extension(ComputeAutoCorrelograms)
 compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory()
+compute_auto_correlograms_sorting_analyzer = ComputeAutoCorrelograms.function_factory()


 def compute_correlograms(
@@ -233,9 +346,6 @@ def compute_correlograms(
     )


-compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__
-
-
 def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]:
     """
     Create the bins for the correlogram, in samples.
@@ -295,7 +405,6 @@ def _compute_num_bins(window_size, bin_size):
     """
     num_half_bins = int(window_size // bin_size)
     num_bins = int(2 * num_half_bins)
-
     return num_bins, num_half_bins


@@ -603,6 +712,333 @@ def _compute_correlograms_one_segment_numba(
                     correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1

+    @numba.jit(
+        nopython=True,
+        nogil=True,
+        cache=False,
+    )
+    def _compute_auto_correlograms_one_segment_numba(
+        correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins
+    ):
+        """
+        Compute the auto-correlograms using `numba` for speed.
+
+        The algorithm works by brute-force iteration through all
+        pairs of spikes, skipping pairs that fall outside of the window
+        or belong to different units. The spike-time difference and its
+        time bin are computed and stored in a (num_units, num_bins)
+        correlogram. The correlogram must be passed as an
+        argument and is filled in-place.
+
+        Parameters
+        ----------
+        correlograms : np.array
+            A (num_units, num_bins) array of auto-correlograms, one row
+            per unit, at each lag time bin. This is passed in so that the
+            counts from all segments are accumulated into it.
+        spike_times : np.ndarray
+            An array of spike times (in samples, not seconds).
+            This contains spikes from all units.
+        spike_unit_indices : np.ndarray
+            An array of labels indicating the unit of the corresponding
+            spike in `spike_times`.
+        window_size : int
+            The window size over which to perform the correlation, in samples
+        bin_size : int
+            The size of the lag bins, in samples.
+        """
+        start_j = 0
+        for i in range(spike_times.size):
+            for j in range(start_j, spike_times.size):
+                if i == j:
+                    continue
+
+                diff = spike_times[i] - spike_times[j]
+
+                # When the diff is exactly the window size, keep going
+                # without iterating start_j in case this spike also has
+                # other diffs that == window size.
+                if diff == window_size:
+                    continue
+
+                # if the time of spike i is more than window size later than
+                # spike j, then spike i + 1 will also be more than a window size
+                # later than spike j. Iterate the start_j and check the next spike.
+                if diff > window_size:
+                    start_j += 1
+                    continue
+
+                # If the time of spike i is more than a window size earlier
+                # than spike j, then all following j spikes will be even later
+                # i spikes and so all more than a window size earlier. So move
+                # onto the next i.
+                if diff < -window_size:
+                    break
+
+                # The window bookkeeping above is unit-independent, so the
+                # unit check happens only now: pairs from different units
+                # are skipped without being counted.
+                if spike_unit_indices[i] != spike_unit_indices[j]:
+                    continue
+
+                bin = diff // bin_size
+
+                correlograms[spike_unit_indices[i], num_half_bins + bin] += 1
+
+
+###### ACG area ######
+
+
+def compute_auto_correlograms(
+    sorting_analyzer_or_sorting,
+    window_ms: float = 50.0,
+    bin_ms: float = 1.0,
+    method: str = "auto",
+):
+    """
+    Compute auto-correlograms using Numba or Numpy.
+    See ComputeAutoCorrelograms() for details.
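+
+    For example (a minimal sketch, not a runnable doctest; ``sorting`` here
+    stands for any SpikeInterface Sorting object, though a SortingAnalyzer
+    is also accepted):
+
+    >>> acgs, bins = compute_auto_correlograms(sorting, window_ms=50.0, bin_ms=1.0)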
+    """
+    if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor):
+        sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting
+
+    if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
+        return compute_auto_correlograms_sorting_analyzer(
+            sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
+        )
+    else:
+        return _compute_auto_correlograms_on_sorting(
+            sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
+        )
+
+
+def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
+    """
+    Computes auto-correlograms for multiple units.
+
+    Entry function to compute auto-correlograms across all units in a `Sorting`
+    object (i.e. each unit's spike train is correlated against itself at all
+    lags within the window).
+
+    Parameters
+    ----------
+    sorting : Sorting
+        A SpikeInterface Sorting object
+    window_ms : float
+        The window size over which to perform the correlation, in ms
+    bin_ms : float
+        The size of the lag bins, in ms.
+    method : str
+        To use "numpy" or "numba". "auto" will use numba if available,
+        otherwise numpy.
+
+    Returns
+    -------
+    correlograms : np.array
+        A (num_units, num_bins) array holding, for each unit, the count of
+        spike-time matches at each lag time bin. Note the true correlation
+        is not returned; the counts of matches are.
+    bins : np.array
+        The bin edges in ms
+    """
+    assert method in ("auto", "numba", "numpy"), "method must be 'auto', 'numba' or 'numpy'"
+
+    if method == "auto":
+        method = "numba" if HAVE_NUMBA else "numpy"
+
+    bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms)
+
+    if method == "numpy":
+        correlograms = _compute_auto_correlograms_numpy(sorting, window_size, bin_size)
+    if method == "numba":
+        correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size)
+
+    return correlograms, bins
+
+
+# LOW-LEVEL IMPLEMENTATIONS
+def _compute_auto_correlograms_numpy(sorting, window_size, bin_size):
+    """
+    Computes auto-correlograms for all units in a sorting object.
+
+    This very elegant implementation is copied from the phy package written by Cyrille Rossant.
+    https://github.com/cortex-lab/phylib/blob/master/phylib/stats/ccg.py
+
+    The main modification is the way positive and negative lags are handled
+    explicitly for rounding reasons.
+
+    Other slight modifications have been made to fit the SpikeInterface
+    data model (e.g. adding the ability to handle multiple segments).
+
+    Adaptation: Samuel Garcia
+    """
+    num_seg = sorting.get_num_segments()
+    num_units = len(sorting.unit_ids)
+    spikes = sorting.to_spike_vector(concatenated=False)
+
+    num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
+
+    correlograms = np.zeros((num_units, num_bins), dtype="int64")
+
+    for seg_index in range(num_seg):
+        spike_times = spikes[seg_index]["sample_index"]
+        spike_unit_indices = spikes[seg_index]["unit_index"]
+
+        c0 = auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size)
+
+        correlograms += c0
+
+    return correlograms
+
+
+def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size):
+    """
+    A very well optimized algorithm for the auto-correlation of
+    spike trains, copied from the phy package, written by Cyrille Rossant.
+
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        An array of spike times (in samples, not seconds).
+        This contains spikes from all units.
+    spike_unit_indices : np.ndarray
+        An array of labels indicating the unit of the corresponding
+        spike in `spike_times`.
+    window_size : int
+        The window size over which to perform the correlation, in samples
+    bin_size : int
+        The size of the lag bins, in samples.
+
+    Returns
+    -------
+    correlograms : np.array
+        A (num_units, num_bins) array of auto-correlograms, one row per
+        unit, at each lag time bin.
+
+    Notes
+    -----
+    For each spike, the time difference to every other spike of the
+    same unit within the window is directly computed and stored as a
+    count in the relevant lag time bin.
+
+    Initially, the spike_times array is shifted by 1 position, and the difference
+    computed. This gives the time differences between the closest spikes
+    (skipping the zero-lag case).
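+    For example (a hand-worked sketch), for a unit with spike times
+    ``[0, 10, 25]`` in samples, the shift-1 differences are ``[10, 15]``
+    and the shift-2 difference is ``[25]``.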
+    Next, the differences between
+    spike times in samples are converted into units relative to
+    bin_size ('binarized'). Spikes whose binarized difference to
+    their closest neighbouring spike falls outside the window (i.e. is
+    more than `num_half_bins` bins away) are masked.
+
+    Finally, the lag bins of the (num_units, num_bins) correlogram
+    that need incrementing are counted with `np.bincount()`. This repeats
+    for increasing shifts along the spike train until no spike has a
+    matching spike within the window size.
+    """
+    num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
+    # Note: as in `correlogram_for_one_segment`, this assumes unit indices
+    # are 0 ... num_units - 1 and that every unit spikes in this segment.
+    num_units = len(np.unique(spike_unit_indices))
+
+    correlograms = np.zeros((num_units, num_bins), dtype="int64")
+
+    for unit_ind in range(num_units):
+        unit_mask = spike_unit_indices == unit_ind
+        spike_times_unit = spike_times[unit_mask]
+
+        # At a given shift, the mask indicates which spikes have matching spikes
+        # within the correlogram time window.
+        mask = np.ones_like(spike_times_unit, dtype="bool")
+
+        # The loop continues as long as there is at least one
+        # spike with a matching spike.
+        shift = 1
+        while mask[:-shift].any():
+            # Number of time samples between spike i and spike i+shift.
+            spike_diff = spike_times_unit[shift:] - spike_times_unit[:-shift]
+
+            for sign in (-1, 1):
+                # Binarize the delays between spike i and spike i+shift for negative
+                # and positive lags (the operator // is np.floor_divide).
+                spike_diff_b = (spike_diff * sign) // bin_size
+
+                # Spikes with no matching spikes are masked.
+                if sign == -1:
+                    mask[:-shift][spike_diff_b < -num_half_bins] = False
+                else:
+                    mask[:-shift][spike_diff_b >= num_half_bins] = False
+
+                m = mask[:-shift]
+
+                # Find the bin indices in this unit's correlogram row that
+                # need to be incremented.
+                indices = spike_diff_b[m] + num_half_bins
+
+                # Increment the matching spikes in the correlograms array.
+                bbins = np.bincount(indices)
+                correlograms[unit_ind, : len(bbins)] += bbins
+
+                if sign == 1:
+                    # For positive sign, the end bin is < num_half_bins (e.g.
+                    # bin = 29, num_half_bins = 30, will go to index 59 (i.e. the
+                    # last bin). For negative sign, the first bin is == num_half_bins
+                    # e.g. bin = -30, with num_half_bins = 30 will go to bin 0. Therefore
+                    # sign == 1 must mask spike_diff_b >= num_half_bins but sign == -1
+                    # must still count cases of spike_diff_b == num_half_bins. So we
+                    # turn the mask back on here for the next loop, which starts
+                    # with the -1 case.
+                    mask[:-shift][spike_diff_b == num_half_bins] = True
+
+            shift += 1
+
+    return correlograms
+
+
+def _compute_auto_correlograms_numba(sorting, window_size, bin_size):
+    """
+    Computes auto-correlograms for all units in `sorting`.
+
+    This is a "brute force" method using compiled code (numba)
+    to accelerate the computation. See
+    `_compute_auto_correlograms_one_segment_numba()` for details.
+
+    Parameters
+    ----------
+    sorting : Sorting
+        A SpikeInterface Sorting object
+    window_size : int
+        The window size over which to perform the correlation, in samples
+    bin_size : int
+        The size of the lag bins, in samples.
+
+    Returns
+    -------
+    correlograms : np.array
+        A (num_units, num_bins) array of auto-correlograms, one row per
+        unit, at each lag time bin.
+
+    Implementation: Aurélien Wyngaard
+    """
+    assert HAVE_NUMBA, "numba version of this function requires installation of numba"
+
+    num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
+    num_units = len(sorting.unit_ids)
+
+    spikes = sorting.to_spike_vector(concatenated=False)
+    correlograms = np.zeros((num_units, num_bins), dtype=np.int64)
+
+    for seg_index in range(sorting.get_num_segments()):
+        spike_times = spikes[seg_index]["sample_index"]
+        spike_unit_indices = spikes[seg_index]["unit_index"]
+
+        _compute_auto_correlograms_one_segment_numba(
+            correlograms,
+            spike_times.astype(np.int64, copy=False),
+            spike_unit_indices.astype(np.int32, copy=False),
+            window_size,
+            bin_size,
+            num_half_bins,
+        )
+
+    return correlograms
+
+
+###### 3D ACG area ######
+

 class ComputeACG3D(AnalyzerExtension):
     """
@@ -1014,3 +1450,5 @@ def compute_acgs_3d(


 compute_acgs_3d.__doc__ = compute_acgs_3d_sorting_analyzer.__doc__
+compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__
+compute_auto_correlograms.__doc__ = compute_auto_correlograms_sorting_analyzer.__doc__
diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py
index 523aa4ba05..569d6c5d80 100644
--- a/src/spikeinterface/postprocessing/tests/test_correlograms.py
+++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -18,13 +18,15 @@
 from pytest import param

 from spikeinterface import NumpySorting, generate_sorting
-from spikeinterface.postprocessing import ComputeACG3D, ComputeCorrelograms
+from spikeinterface.postprocessing import ComputeACG3D, ComputeCorrelograms, ComputeAutoCorrelograms
 from spikeinterface.postprocessing.correlograms import (
     _compute_3d_acg_one_unit,
     _compute_correlograms_on_sorting,
+    _compute_auto_correlograms_on_sorting,
     _make_bins,
     compute_acgs_3d,
     compute_correlograms,
+    compute_auto_correlograms,
 )
 from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite

@@ -55,13 +57,42 @@ def test_sortinganalyzer_correlograms(self, method):

         params = dict(method=method, window_ms=100, bin_ms=6.5)
         ext_numpy = sorting_analyzer.compute(ComputeCorrelograms.extension_name, **params)
-
         result_sorting, bins_sorting = compute_correlograms(self.sorting, **params)

         assert np.array_equal(result_sorting, ext_numpy.data["ccgs"])
         assert np.array_equal(bins_sorting, ext_numpy.data["bins"])


+class TestComputeAutoCorrelograms(AnalyzerExtensionCommonTestSuite):
+    @pytest.mark.parametrize(
+        "params",
+        [
+            dict(method="numpy"),
+            dict(method="auto"),
+            param(dict(method="numba"), marks=SKIP_NUMBA),
+        ],
+    )
+    def test_extension(self, params):
+        self.run_extension_tests(ComputeAutoCorrelograms, params)
+
+    @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
+    def test_sortinganalyzer_auto_correlograms(self, method):
+        """
+        Test the outputs when using a SortingAnalyzer against the outputs
+        from passing the sorting directly to `compute_auto_correlograms`.
+        Passing a sorting directly is tested extensively below, so if these
+        match it means the `SortingAnalyzer` path is working.
+ """ + sorting_analyzer = self._prepare_sorting_analyzer("memory", sparse=False, extension_class=ComputeCorrelograms) + + params = dict(method=method, window_ms=100, bin_ms=6.5) + ext_numpy = sorting_analyzer.compute(ComputeAutoCorrelograms.extension_name, **params) + result_sorting, bins_sorting = compute_auto_correlograms(self.sorting, **params) + + assert np.array_equal(result_sorting, ext_numpy.data["acgs"]) + assert np.array_equal(bins_sorting, ext_numpy.data["bins"]) + + # Unit Tests ############ def test_make_bins(): @@ -104,6 +135,48 @@ def test_equal_results_correlograms(window_and_bin_ms): assert np.array_equal(result_numpy, result_numba) +@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") +@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)]) +def test_equal_results_auto_correlograms(window_and_bin_ms): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. + """ + + window_ms, bin_ms = window_and_bin_ms + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) + + result_numpy, bins_numpy = _compute_auto_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numpy" + ) + result_numba, bins_numba = _compute_auto_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba" + ) + + assert np.array_equal(result_numpy, result_numba) + + +@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") +@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)]) +def test_equal_results_auto_correlograms(window_and_bin_ms): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. + """ + + window_ms, bin_ms = window_and_bin_ms + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) + + result_numpy, bins_numpy = _compute_auto_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numpy" + ) + result_numba, bins_numba = _compute_auto_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba" + ) + + assert np.array_equal(result_numpy, result_numba) + + @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) def test_flat_cross_correlogram(method): """ @@ -263,6 +336,38 @@ def test_compute_correlograms_different_units(method): assert np.array_equal(result[0, 1], np.array([1, 0, 1, 1, 1, 0, 0, 0])) +@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) +def test_compute_auto_correlograms_different_units(method): + """ + Make a supplementary test to `test_compute_correlograms` in which all + units had the same spike train. Test here a simpler and accessible + test case with only two neurons with different spike time differences + within and across units. + + This case is simple enough to validate by hand, for example for the + result[1, 1] case we are looking at the autocorrelogram of the unit '1'. + The spike times are 4 and 16 s, therefore we expect to see a count in + the +/- 10 to 15 s bin. 
+ """ + sampling_frequency = 30000 + spike_times = np.array([0, 4, 8, 16]) / 1000 * sampling_frequency + spike_times.astype(int) + + spike_unit_indices = np.array([0, 1, 0, 1]) + + window_ms = 40 + bin_ms = 5 + + sorting = NumpySorting.from_samples_and_labels( + samples_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency + ) + + result, bins = compute_auto_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + + assert np.array_equal(result[0], np.array([0, 0, 1, 0, 0, 1, 0, 0])) + assert np.array_equal(result[1], np.array([0, 1, 0, 0, 0, 0, 1, 0])) + + def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin_edge): """ This generates a detailed correlogram test and expected outputs, for a number of